diff --git a/.travis.yml b/.travis.yml
index 2a221e3bd..85a57f6e3 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,33 +1,38 @@
 # XXX: Precise is already deprecated, new default is Trusty.
 # https://blog.travis-ci.com/2017-07-11-trusty-as-default-linux-is-coming
-dist: precise
+dist: focal
 sudo: true
 
 language: c
 
 matrix:
   include:
     - &test-ubuntu
-      os: linux
+#      os: linux
      compiler: gcc
       addons:
         apt:
           packages:
             - gfortran
+#      before_script: &common-before
+#        - COMMON_FLAGS="DYNAMIC_ARCH=1 TARGET=NEHALEM NUM_THREADS=32"
+#      script:
+#        - make QUIET_MAKE=1 $COMMON_FLAGS $BTYPE
+#        - make -C test $COMMON_FLAGS $BTYPE
+#        - make -C ctest $COMMON_FLAGS $BTYPE
+#        - make -C utest $COMMON_FLAGS $BTYPE
+#      env:
+#        - TARGET_BOX=LINUX64
+#        - BTYPE="BINARY=64"
+#
+#    - <<: *test-ubuntu
       os: linux-ppc64le
       before_script: &common-before
-        - COMMON_FLAGS="DYNAMIC_ARCH=1 TARGET=NEHALEM NUM_THREADS=32"
+        - COMMON_FLAGS="DYNAMIC_ARCH=1 TARGET=POWER8 NUM_THREADS=32"
       script:
         - make QUIET_MAKE=1 $COMMON_FLAGS $BTYPE
         - make -C test $COMMON_FLAGS $BTYPE
         - make -C ctest $COMMON_FLAGS $BTYPE
         - make -C utest $COMMON_FLAGS $BTYPE
-      env:
-        - TARGET_BOX=LINUX64
-        - BTYPE="BINARY=64"
-
-    - <<: *test-ubuntu
-      os: linux-ppc64le
-      before_script:
-        - COMMON_FLAGS="DYNAMIC_ARCH=1 TARGET=POWER8 NUM_THREADS=32"
       env:
         # for matrix annotation only
         - TARGET_BOX=PPC64LE_LINUX
@@ -55,38 +60,38 @@ matrix:
         - TARGET_BOX=IBMZ_LINUX
         - BTYPE="BINARY=64 USE_OPENMP=0 CC=clang"
 
-    - <<: *test-ubuntu
-      env:
-        - TARGET_BOX=LINUX64
-        - BTYPE="BINARY=64 USE_OPENMP=1"
-
-    - <<: *test-ubuntu
-      env:
-        - TARGET_BOX=LINUX64
-        - BTYPE="BINARY=64 INTERFACE64=1"
-
-    - <<: *test-ubuntu
-      compiler: clang
-      env:
-        - TARGET_BOX=LINUX64
-        - BTYPE="BINARY=64 CC=clang"
-
-    - <<: *test-ubuntu
-      compiler: clang
-      env:
-        - TARGET_BOX=LINUX64
-        - BTYPE="BINARY=64 INTERFACE64=1 CC=clang"
-
-    - <<: *test-ubuntu
-      addons:
-        apt:
-          packages:
-            - gcc-multilib
-            - gfortran-multilib
-      env:
-        - TARGET_BOX=LINUX32
-        - BTYPE="BINARY=32"
-
+#    - <<: *test-ubuntu
+#      env:
+#        - TARGET_BOX=LINUX64
+#        - BTYPE="BINARY=64 USE_OPENMP=1"
+#
+#    - <<: *test-ubuntu
+#      env:
+#        - TARGET_BOX=LINUX64
+#        - BTYPE="BINARY=64 INTERFACE64=1"
+#
+#    - <<: *test-ubuntu
+#      compiler: clang
+#      env:
+#        - TARGET_BOX=LINUX64
+#        - BTYPE="BINARY=64 CC=clang"
+#
+#    - <<: *test-ubuntu
+#      compiler: clang
+#      env:
+#        - TARGET_BOX=LINUX64
+#        - BTYPE="BINARY=64 INTERFACE64=1 CC=clang"
+#
+#    - <<: *test-ubuntu
+#      addons:
+#        apt:
+#          packages:
+#            - gcc-multilib
+#            - gfortran-multilib
+#      env:
+#        - TARGET_BOX=LINUX32
+#        - BTYPE="BINARY=32"
+#
     - os: linux
       arch: ppc64le
       dist: bionic
@@ -121,47 +126,47 @@ matrix:
         # for matrix annotation only
         - TARGET_BOX=PPC64LE_LINUX_P9
 
-    - os: linux
-      compiler: gcc
-      addons:
-        apt:
-          packages:
-            - binutils-mingw-w64-x86-64
-            - gcc-mingw-w64-x86-64
-            - gfortran-mingw-w64-x86-64
-      before_script: *common-before
-      script:
-        - travis_wait 45 make QUIET_MAKE=1 $COMMON_FLAGS $BTYPE
-      env:
-        - TARGET_BOX=WIN64
-        - BTYPE="BINARY=64 HOSTCC=gcc CC=x86_64-w64-mingw32-gcc FC=x86_64-w64-mingw32-gfortran"
-
+#    - os: linux
+#      compiler: gcc
+#      addons:
+#        apt:
+#          packages:
+#            - binutils-mingw-w64-x86-64
+#            - gcc-mingw-w64-x86-64
+#            - gfortran-mingw-w64-x86-64
+#      before_script: *common-before
+#      script:
+#        - travis_wait 45 make QUIET_MAKE=1 $COMMON_FLAGS $BTYPE
+#      env:
+#        - TARGET_BOX=WIN64
+#        - BTYPE="BINARY=64 HOSTCC=gcc CC=x86_64-w64-mingw32-gcc FC=x86_64-w64-mingw32-gfortran"
+#
 # Build & test on Alpine Linux inside chroot, i.e. on system with musl libc.
 # These jobs need sudo, so Travis runs them on VM-based infrastructure
 # which is slower than container-based infrastructure used for jobs
 # that don't require sudo.
-    - &test-alpine
-      os: linux
-      dist: trusty
-      sudo: true
-      language: minimal
-      before_install:
-        - "wget 'https://raw.githubusercontent.com/alpinelinux/alpine-chroot-install/v0.9.0/alpine-chroot-install' \
-            && echo 'e5dfbbdc0c4b3363b99334510976c86bfa6cb251  alpine-chroot-install' | sha1sum -c || exit 1"
-        - alpine() { /alpine/enter-chroot -u "$USER" "$@"; }
-      install:
-        - sudo sh alpine-chroot-install -p 'build-base gfortran perl linux-headers'
-      before_script: *common-before
-      script:
-        # XXX: Disable some warnings for now to avoid exceeding Travis limit for log size.
-        - alpine make QUIET_MAKE=1 $COMMON_FLAGS $BTYPE
-            CFLAGS="-Wno-misleading-indentation -Wno-sign-conversion -Wno-incompatible-pointer-types"
-        - alpine make -C test $COMMON_FLAGS $BTYPE
-        - alpine make -C ctest $COMMON_FLAGS $BTYPE
-        - alpine make -C utest $COMMON_FLAGS $BTYPE
-      env:
-        - TARGET_BOX=LINUX64_MUSL
-        - BTYPE="BINARY=64"
+  #  - &test-alpine
+  #    os: linux
+  #    dist: trusty
+  #    sudo: true
+  #    language: minimal
+  #    before_install:
+  #      - "wget 'https://raw.githubusercontent.com/alpinelinux/alpine-chroot-install/v0.9.0/alpine-chroot-install' \
+  #          && echo 'e5dfbbdc0c4b3363b99334510976c86bfa6cb251  alpine-chroot-install' | sha1sum -c || exit 1"
+  #      - alpine() { /alpine/enter-chroot -u "$USER" "$@"; }
+  #    install:
+  #      - sudo sh alpine-chroot-install -p 'build-base gfortran perl linux-headers'
+  #    before_script: *common-before
+  #    script:
+  #      # XXX: Disable some warnings for now to avoid exceeding Travis limit for log size.
+  #      - alpine make QUIET_MAKE=1 $COMMON_FLAGS $BTYPE
+  #          CFLAGS="-Wno-misleading-indentation -Wno-sign-conversion -Wno-incompatible-pointer-types"
+  #      - alpine make -C test $COMMON_FLAGS $BTYPE
+  #      - alpine make -C ctest $COMMON_FLAGS $BTYPE
+  #      - alpine make -C utest $COMMON_FLAGS $BTYPE
+  #    env:
+  #      - TARGET_BOX=LINUX64_MUSL
+  #      - BTYPE="BINARY=64"
 
 # XXX: This job segfaults in TESTS OF THE COMPLEX LEVEL 3 BLAS,
 # but only on Travis CI, cannot reproduce it elsewhere.
@@ -171,98 +176,98 @@ matrix:
 #      - TARGET_BOX=LINUX64_MUSL
 #      - BTYPE="BINARY=64 USE_OPENMP=1"
 
-    - <<: *test-alpine
-      env:
-        - TARGET_BOX=LINUX64_MUSL
-        - BTYPE="BINARY=64 INTERFACE64=1"
+#    - <<: *test-alpine
+#      env:
+#        - TARGET_BOX=LINUX64_MUSL
+#        - BTYPE="BINARY=64 INTERFACE64=1"
+#
+#    # Build with the same flags as Alpine does in its OpenBLAS package.
+#    - <<: *test-alpine
+#      env:
+#        - TARGET_BOX=LINUX64_MUSL
+#        - BTYPE="BINARY=64 NO_AFFINITY=1 USE_OPENMP=0 NO_LAPACK=0 TARGET=CORE2"
 
-    # Build with the same flags as Alpine does in its OpenBLAS package.
-    - <<: *test-alpine
-      env:
-        - TARGET_BOX=LINUX64_MUSL
-        - BTYPE="BINARY=64 NO_AFFINITY=1 USE_OPENMP=0 NO_LAPACK=0 TARGET=CORE2"
+#    - &test-cmake
+#      os: linux
+#      compiler: clang
+#      addons:
+#        apt:
+#          packages:
+#            - gfortran
+#            - cmake
+#      dist: trusty
+#      sudo: true
+#      before_script:
+#        - COMMON_ARGS="-DTARGET=NEHALEM -DNUM_THREADS=32"
+#      script:
+#        - mkdir build
+#        - CONFIG=Release
+#        - cmake -Bbuild -H. $CMAKE_ARGS $COMMON_ARGS -DCMAKE_BUILD_TYPE=$CONFIG
+#        - cmake --build build --config $CONFIG -- -j2
+#      env:
+#        - CMAKE=1
+#    - <<: *test-cmake
+#      env:
+#        - CMAKE=1 CMAKE_ARGS="-DNOFORTRAN=1"
+#    - <<: *test-cmake
+#      compiler: gcc
+#      env:
+#        - CMAKE=1
 
-    - &test-cmake
-      os: linux
-      compiler: clang
-      addons:
-        apt:
-          packages:
-            - gfortran
-            - cmake
-      dist: trusty
-      sudo: true
-      before_script:
-        - COMMON_ARGS="-DTARGET=NEHALEM -DNUM_THREADS=32"
-      script:
-        - mkdir build
-        - CONFIG=Release
-        - cmake -Bbuild -H. $CMAKE_ARGS $COMMON_ARGS -DCMAKE_BUILD_TYPE=$CONFIG
-        - cmake --build build --config $CONFIG -- -j2
-      env:
-        - CMAKE=1
-    - <<: *test-cmake
-      env:
-        - CMAKE=1 CMAKE_ARGS="-DNOFORTRAN=1"
-    - <<: *test-cmake
-      compiler: gcc
-      env:
-        - CMAKE=1
-
-    - &test-macos
-      os: osx
-      osx_image: xcode11.5
-      before_script:
-        - COMMON_FLAGS="DYNAMIC_ARCH=1 NUM_THREADS=32"
-      script:
-        - travis_wait 45 make QUIET_MAKE=1 $COMMON_FLAGS $BTYPE
-      env:
-        - BTYPE="TARGET=NEHALEM BINARY=64 INTERFACE64=1 FC=gfortran-9"
-
-    - <<: *test-macos
-      osx_image: xcode12
-      before_script:
-        - COMMON_FLAGS="DYNAMIC_ARCH=1 NUM_THREADS=32"
-        - brew update
-      script:
-        - travis_wait 45 make QUIET_MAKE=1 $COMMON_FLAGS $BTYPE
-      env:
-        - BTYPE="TARGET=HASWELL USE_OPENMP=1 BINARY=64 INTERFACE64=1 CC=gcc-10 FC=gfortran-10"
-
-    - <<: *test-macos
-      osx_image: xcode12
-      before_script:
-        - COMMON_FLAGS="DYNAMIC_ARCH=1 NUM_THREADS=32"
-        - brew update
-      script:
-        - travis_wait 45 make QUIET_MAKE=1 $COMMON_FLAGS $BTYPE
-      env:
-        - BTYPE="TARGET=NEHALEM BINARY=64 INTERFACE64=1 FC=gfortran-10"
+#    - &test-macos
+#      os: osx
+#      osx_image: xcode11.5
+#      before_script:
+#        - COMMON_FLAGS="DYNAMIC_ARCH=1 NUM_THREADS=32"
+#      script:
+#        - travis_wait 45 make QUIET_MAKE=1 $COMMON_FLAGS $BTYPE
+#      env:
+#        - BTYPE="TARGET=NEHALEM BINARY=64 INTERFACE64=1 FC=gfortran-9"
+#
+#    - <<: *test-macos
+#      osx_image: xcode12
+#      before_script:
+#        - COMMON_FLAGS="DYNAMIC_ARCH=1 NUM_THREADS=32"
+#        - brew update
+#      script:
+#        - travis_wait 45 make QUIET_MAKE=1 $COMMON_FLAGS $BTYPE
+#      env:
+#        - BTYPE="TARGET=HASWELL USE_OPENMP=1 BINARY=64 INTERFACE64=1 CC=gcc-10 FC=gfortran-10"
+#
+#    - <<: *test-macos
+#      osx_image: xcode12
+#      before_script:
+#        - COMMON_FLAGS="DYNAMIC_ARCH=1 NUM_THREADS=32"
+#        - brew update
+#      script:
+#        - travis_wait 45 make QUIET_MAKE=1 $COMMON_FLAGS $BTYPE
+#      env:
+#        - BTYPE="TARGET=NEHALEM BINARY=64 INTERFACE64=1 FC=gfortran-10"
 
 #    - <<: *test-macos
 #      osx_image: xcode10
 #      env:
 #        - BTYPE="TARGET=NEHALEM BINARY=32 NOFORTRAN=1"
 
-    - <<: *test-macos
-      osx_image: xcode11.5
-      before_script:
-        - COMMON_FLAGS="DYNAMIC_ARCH=1 NUM_THREADS=32"
-        - brew update
-      env:
+#    - <<: *test-macos
+#      osx_image: xcode11.5
+#      before_script:
+#        - COMMON_FLAGS="DYNAMIC_ARCH=1 NUM_THREADS=32"
+#        - brew update
+#      env:
 #        - CC="/Applications/Xcode-10.1.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/bin/clang"
 #        - CFLAGS="-O2 -Wno-macro-redefined -isysroot /Applications/Xcode-10.1.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS12.1.sdk -arch arm64 -miphoneos-version-min=10.0"
-        - CC="/Applications/Xcode-11.5.GM.Seed.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/bin/clang"
-        - CFLAGS="-O2 -Wno-macro-redefined -isysroot /Applications/Xcode-11.5.GM.Seed.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS13.5.sdk -arch arm64 -miphoneos-version-min=10.0"
-        - BTYPE="TARGET=ARMV8 BINARY=64 HOSTCC=clang NOFORTRAN=1"
-    - <<: *test-macos
-      osx_image: xcode11.5
-      env:
-#        - CC="/Applications/Xcode-10.1.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/bin/clang"
-#        - CFLAGS="-O2 -mno-thumb -Wno-macro-redefined -isysroot /Applications/Xcode-10.1.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS12.1.sdk -arch armv7 -miphoneos-version-min=5.1"
-        - CC="/Applications/Xcode-11.5.GM.Seed.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/bin/clang"
-        - CFLAGS="-O2 -mno-thumb -Wno-macro-redefined -isysroot /Applications/Xcode-11.5.GM.Seed.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS13.5.sdk -arch armv7 -miphoneos-version-min=5.1"
-        - BTYPE="TARGET=ARMV7 HOSTCC=clang NOFORTRAN=1"
+#        - CC="/Applications/Xcode-11.5.GM.Seed.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/bin/clang"
+#        - CFLAGS="-O2 -Wno-macro-redefined -isysroot /Applications/Xcode-11.5.GM.Seed.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS13.5.sdk -arch arm64 -miphoneos-version-min=10.0"
+#        - BTYPE="TARGET=ARMV8 BINARY=64 HOSTCC=clang NOFORTRAN=1"
+#    - <<: *test-macos
+#      osx_image: xcode11.5
+#      env:
+##        - CC="/Applications/Xcode-10.1.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/bin/clang"
+##        - CFLAGS="-O2 -mno-thumb -Wno-macro-redefined -isysroot /Applications/Xcode-10.1.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS12.1.sdk -arch armv7 -miphoneos-version-min=5.1"
+#        - CC="/Applications/Xcode-11.5.GM.Seed.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/bin/clang"
+#        - CFLAGS="-O2 -mno-thumb -Wno-macro-redefined -isysroot /Applications/Xcode-11.5.GM.Seed.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS13.5.sdk -arch armv7 -miphoneos-version-min=5.1"
+#        - BTYPE="TARGET=ARMV7 HOSTCC=clang NOFORTRAN=1"
 
     - &test-graviton2
       os: linux
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1ea2e551c..e63f7e04c 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -132,7 +132,7 @@ endif ()
 
 if (BUILD_BFLOAT16)
   message(STATUS "Building Half Precision")
-  list(APPEND FLOAT_TYPES "BFLOAT16") # defines nothing
+  # list(APPEND FLOAT_TYPES "BFLOAT16") # defines nothing
 endif ()
 
 if (NOT DEFINED CORE OR "${CORE}" STREQUAL "UNKNOWN")
diff --git a/Changelog.txt b/Changelog.txt
index ee0484e2b..59fe1d45e 100644
--- a/Changelog.txt
+++ b/Changelog.txt
@@ -1,4 +1,47 @@
 OpenBLAS ChangeLog
+====================================================================
+Version 0.3.18
+	02-Oct-2021
+
+general:
+	- when the build-time number of preconfigured threads is exceeded
+	  at runtime (typically by an external program calling BLAS functions
+	  from a larger number of threads in parallel), OpenBLAS will now
+	  allocate an auxiliary control structure for up to 512 additional
+	  threads instead of aborting
+	- added support for Loongson's LoongArch64 cpu architecture
+	- fixed building OpenBLAS with CMAKE and -DBUILD_BFLOAT16=ON
+	- added support for building OpenBLAS as a CMAKE subproject
+	- added support for building for Windows/ARM64 targets with clang
+	- improved support for building with the IBM xlf compiler
+	- imported Reference-LAPACK PR 625 (out-of-bounds reads in ?LARRV)
+	- imported Reference-LAPACK PR 597 for testsuite compatibility with
+	  LLVM's libomp
+
+x86_64:
+	- added SkylakeX S/DGEMM kernels for small problem sizes (M*N*K<=1000000)
+	- added optimized SBGEMM for Intel Cooper Lake
+	- reinstated the performance patch for AVX512 SGEMV_T with a proper fix
+	- added a workaround for a gcc11 tree-vectorizer bug that caused spurious
+	  failures in the test programs for complex BLAS3 when compiling at -O3
+	  (the default for cmake "release" builds)
+	- added support for runtime cpu count detection under Haiku OS
+	- worked around a long-standing miscompilation issue of the Haswell DGEMV_T
+	  kernel with gcc that could produce NaN output in some corner cases
+
+POWER:
+	- improved performance of DASUM on POWER10
+
+ARMV8:
+	- fixed crashes (use of reserved register x18) on Apple M1 under OSX
+	- fixed building with gcc releases earlier than 5.1
+
+MIPS:
+	- fixed building under BSD
+
+MIPS64:
+	- fixed building under BSD
+
 ====================================================================
 Version 0.3.17
 	15-Jul-2021
diff --git a/Makefile b/Makefile
index 555d1c467..49fd57ff2 100644
--- a/Makefile
+++ b/Makefile
@@ -269,7 +269,7 @@ prof_lapack : lapack_prebuild
 lapack_prebuild :
 ifeq ($(NOFORTRAN), $(filter 0,$(NOFORTRAN)))
 	-@echo "FC = $(FC)" > $(NETLIB_LAPACK_DIR)/make.inc
-	-@echo "FFLAGS = $(LAPACK_FFLAGS)" >> $(NETLIB_LAPACK_DIR)/make.inc
+	-@echo "override FFLAGS = $(LAPACK_FFLAGS)" >> $(NETLIB_LAPACK_DIR)/make.inc
 	-@echo "FFLAGS_DRV = $(LAPACK_FFLAGS)" >> $(NETLIB_LAPACK_DIR)/make.inc
 	-@echo "POPTS = $(LAPACK_FPFLAGS)" >> $(NETLIB_LAPACK_DIR)/make.inc
 	-@echo "FFLAGS_NOOPT = -O0 $(LAPACK_NOOPT)" >> $(NETLIB_LAPACK_DIR)/make.inc
diff --git a/Makefile.arm64 b/Makefile.arm64
index c23a0876e..2656a17f9 100644
--- a/Makefile.arm64
+++ b/Makefile.arm64
@@ -1,4 +1,15 @@
 ifneq ($(C_COMPILER), PGI)
+
+ifneq ($(GCCVERSIONGT4), 1)
+CCOMMON_OPT += -march=armv8-a
+ifneq ($(F_COMPILER), NAG)
+FCOMMON_OPT += -march=armv8-a
+endif
+
+
+else
+
+
 ifeq ($(CORE), ARMV8)
 CCOMMON_OPT += -march=armv8-a
 ifneq ($(F_COMPILER), NAG)
@@ -138,4 +149,7 @@ FCOMMON_OPT += -march=armv8-a -mtune=emag
 endif
 endif
 endif
+
 endif
+
+endif
\ No newline at end of file
diff --git a/Makefile.loongarch64 b/Makefile.loongarch64
new file mode 100644
index 000000000..05ea9c679
--- /dev/null
+++ b/Makefile.loongarch64
@@ -0,0 +1,3 @@
+ifdef BINARY64
+else
+endif
diff --git a/Makefile.power b/Makefile.power
index 946f55232..28a0bae08 100644
--- a/Makefile.power
+++ b/Makefile.power
@@ -12,9 +12,13 @@ endif
 ifeq ($(CORE), POWER10)
 ifneq ($(C_COMPILER), PGI)
 CCOMMON_OPT += -Ofast -mcpu=power10 -mtune=power10 -mvsx -fno-fast-math
+ifeq ($(F_COMPILER), IBM)
+FCOMMON_OPT += -O2 -qrecur -qnosave
+else
 FCOMMON_OPT += -O2 -frecursive -mcpu=power10 -mtune=power10 -fno-fast-math
 endif
 endif
+endif
 
 ifeq ($(CORE), POWER9)
 ifneq ($(C_COMPILER), PGI)
@@ -33,7 +37,11 @@ else
 CCOMMON_OPT += -fast -Mvect=simd -Mcache_align
 endif
 ifneq ($(F_COMPILER), PGI)
+ifeq ($(F_COMPILER), IBM)
+FCOMMON_OPT += -O2 -qrecur -qnosave
+else
 FCOMMON_OPT += -O2 -frecursive -fno-fast-math
+endif
 ifeq ($(C_COMPILER), GCC)
 ifneq ($(GCCVERSIONGT4), 1)
 $(warning your compiler is too old to fully support POWER9, getting a newer version of gcc is recommended)
@@ -57,7 +65,11 @@ CCOMMON_OPT += -fast -Mvect=simd -Mcache_align
 endif
 ifneq ($(F_COMPILER), PGI)
 ifeq ($(OSNAME), AIX)
+ifeq ($(F_COMPILER), IBM)
+FCOMMON_OPT += -O2 -qrecur -qnosave
+else
 FCOMMON_OPT += -O1 -frecursive -mcpu=power8 -mtune=power8 -fno-fast-math
+endif
 else
 FCOMMON_OPT += -O2 -frecursive -mcpu=power8 -mtune=power8 -fno-fast-math
 endif
diff --git a/Makefile.rule b/Makefile.rule
index 2e0980fa9..7c04a3101 100644
--- a/Makefile.rule
+++ b/Makefile.rule
@@ -3,7 +3,7 @@
 #
 
 # This library's version
-VERSION = 0.3.17
+VERSION = 0.3.17.dev
 
 # If you set the suffix, the library name will be libopenblas_$(LIBNAMESUFFIX).a
 # and libopenblas_$(LIBNAMESUFFIX).so. Meanwhile, the soname in shared library
diff --git a/Makefile.system b/Makefile.system
index bb8c60e91..150dbef50 100644
--- a/Makefile.system
+++ b/Makefile.system
@@ -33,6 +33,10 @@ else ifeq ($(ARCH), armv7)
 override ARCH=arm
 else ifeq ($(ARCH), aarch64)
 override ARCH=arm64
+else ifeq ($(ARCH), mipsel)
+override ARCH=mips
+else ifeq ($(ARCH), mips64el)
+override ARCH=mips64
 else ifeq ($(ARCH), zarch)
 override ARCH=zarch
 endif
@@ -244,6 +248,14 @@ else
 ONLY_CBLAS = 0
 endif
 
+#For small matrix optimization
+ifeq ($(ARCH), x86_64)
+SMALL_MATRIX_OPT = 1
+endif
+ifeq ($(SMALL_MATRIX_OPT), 1)
+CCOMMON_OPT += -DSMALL_MATRIX_OPT
+endif
+
 # This operation is expensive, so execution should be once.
 ifndef GOTOBLAS_MAKEFILE
 export GOTOBLAS_MAKEFILE = 1
@@ -780,6 +792,11 @@ NO_BINARY_MODE = 1
 BINARY_DEFINED = 1
 endif
 
+ifeq ($(ARCH), loongarch64)
+NO_BINARY_MODE = 1
+BINARY_DEFINED = 1
+endif
+
 
 #
 # C Compiler dependent settings
@@ -850,6 +867,13 @@ ifeq ($(OSNAME), AIX)
 BINARY_DEFINED = 1
 endif
 
+ifeq ($(ARCH), loongarch64)
+ifeq ($(CORE), LOONGSON3R5)
+CCOMMON_OPT += -march=loongarch64 -mabi=lp64
+FCOMMON_OPT += -march=loongarch64 -mabi=lp64
+endif
+endif
+
 endif
 
 ifndef BINARY_DEFINED
diff --git a/README.md b/README.md
index d7e0d60a7..6ce85e08e 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
 
 [![Join the chat at https://gitter.im/xianyi/OpenBLAS](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/xianyi/OpenBLAS?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
 
-Travis CI: [![Build Status](https://travis-ci.org/xianyi/OpenBLAS.svg?branch=develop)](https://travis-ci.org/xianyi/OpenBLAS)
+Travis CI: [![Build Status](https://travis-ci.com/xianyi/OpenBLAS.svg?branch=develop)](https://travis-ci.com/xianyi/OpenBLAS)
 
 AppVeyor: [![Build status](https://ci.appveyor.com/api/projects/status/09sohd35n8nkkx64/branch/develop?svg=true)](https://ci.appveyor.com/project/xianyi/openblas/branch/develop)
 
@@ -128,6 +128,7 @@ Please read `GotoBLAS_01Readme.txt` for older CPU models already supported by th
 - **Intel Sandy Bridge**: Optimized Level-3 and Level-2 BLAS with AVX on x86-64.
 - **Intel Haswell**: Optimized Level-3 and Level-2 BLAS with AVX2 and FMA on x86-64.
 - **Intel Skylake-X**: Optimized Level-3 and Level-2 BLAS with AVX512 and FMA on x86-64.
+- **Intel Cooper Lake**: as Skylake-X with improved BFLOAT16 support.
 - **AMD Bobcat**: Used GotoBLAS2 Barcelona codes.
 - **AMD Bulldozer**: x86-64 ?GEMM FMA4 kernels. (Thanks to Werner Saar)
 - **AMD PILEDRIVER**: Uses Bulldozer codes with some optimizations.
@@ -153,6 +154,7 @@ Please read `GotoBLAS_01Readme.txt` for older CPU models already supported by th
 
 - **ARMv8**: Basic ARMV8 with small caches, optimized Level-3 and Level-2 BLAS
 - **Cortex-A53**: same as ARMV8 (different cpu specifications)
+- **Cortex-A55**: same as ARMV8 (different cpu specifications)
 - **Cortex A57**: Optimized Level-3 and Level-2 functions
 - **Cortex A72**: same as A57 (different cpu specifications)
 - **Cortex A73**: same as A57 (different cpu specifications)
@@ -178,10 +180,11 @@ Please read `GotoBLAS_01Readme.txt` for older CPU models already supported by th
 
 #### RISC-V
 
-- **C910V**: Optimized Leve-3 BLAS (real) and Level-1,2 by RISC-V Vector extension 0.7.1.
+- **C910V**: Optimized Level-3 BLAS (real) and Level-1,2 by RISC-V Vector extension 0.7.1.
 ```sh
 make HOSTCC=gcc TARGET=C910V CC=riscv64-unknown-linux-gnu-gcc FC=riscv64-unknown-linux-gnu-gfortran
 ```
+  (also known to work on C906)
 
 ### Support for multiple targets in a single library
 
diff --git a/TargetList.txt b/TargetList.txt
index f93a629d8..963545cdd 100644
--- a/TargetList.txt
+++ b/TargetList.txt
@@ -110,3 +110,5 @@ Z14
 RISCV64_GENERIC
 C910V
 
+11.LOONGARCH64:
+LOONGSON3R5
diff --git a/azure-pipelines.yml b/azure-pipelines.yml
index 889b920e3..f9e79018b 100644
--- a/azure-pipelines.yml
+++ b/azure-pipelines.yml
@@ -19,7 +19,7 @@ jobs:
 # of gcc / glibc
 - job: manylinux1_gcc
   pool:
-    vmImage: 'ubuntu-16.04'
+    vmImage: 'ubuntu-latest'
   steps:
   - script: |
       echo "FROM quay.io/pypa/manylinux1_x86_64
@@ -35,7 +35,7 @@ jobs:
     displayName: Run manylinux1 docker build
 - job: Intel_SDE_skx
   pool:
-    vmImage: 'ubuntu-16.04'
+    vmImage: 'ubuntu-latest'
   steps:
   - script: |
       # at the time of writing the available Azure Ubuntu vm image
@@ -83,6 +83,8 @@ jobs:
   - script: |
       brew update
       make TARGET=CORE2 DYNAMIC_ARCH=1 USE_OPENMP=1 INTERFACE64=1 CC=gcc-10 FC=gfortran-10
+      make TARGET=CORE2 DYNAMIC_ARCH=1 USE_OPENMP=1 INTERFACE64=1 CC=gcc-10 FC=gfortran-10 PREFIX=../blasinst install
+      ls -lR ../blasinst
 
 - job: OSX_GCC_Nothreads
   pool:
@@ -104,6 +106,38 @@ jobs:
       brew install llvm libomp
       make TARGET=CORE2 USE_OPENMP=1 INTERFACE64=1 DYNAMIC_ARCH=1 CC=/usr/local/opt/llvm/bin/clang FC=gfortran-10
 
+- job: OSX_OpenMP_Clang_cmake
+  pool:
+    vmImage: 'macOS-10.15'
+  variables:
+    LD_LIBRARY_PATH: /usr/local/opt/llvm/lib
+    LIBRARY_PATH: /usr/local/opt/llvm/lib
+  steps:
+  - script: |
+      brew update
+      brew install llvm libomp
+      mkdir build
+      cd build
+      cmake -DTARGET=CORE2 -DUSE_OPENMP=1 -DINTERFACE64=1 -DDYNAMIC_ARCH=1 -DCMAKE_C_COMPILER=/usr/local/opt/llvm/bin/clang -DNOFORTRAN=1 -DNO_AVX512=1 ..
+      make
+      ctest
+
+- job: OSX_OpenMP_Clang_gf_cmake
+  pool:
+    vmImage: 'macOS-10.15'
+  variables:
+    LD_LIBRARY_PATH: /usr/local/opt/llvm/lib
+    LIBRARY_PATH: /usr/local/opt/llvm/lib
+  steps:
+  - script: |
+      brew update
+      brew install llvm libomp
+      mkdir build
+      cd build
+      cmake -DTARGET=CORE2 -DUSE_OPENMP=1 -DINTERFACE64=1 -DDYNAMIC_ARCH=1 -DCMAKE_C_COMPILER=/usr/local/opt/llvm/bin/clang -DNO_AVX512=1 ..
+      make
+      ctest
+
 - job: OSX_Ifort_Clang
   pool:
     vmImage: 'macOS-10.15'
@@ -146,14 +180,35 @@ jobs:
       brew install --cask android-ndk
       export ANDROID_NDK_HOME=/usr/local/share/android-ndk
       make TARGET=ARMV7 ONLY_CBLAS=1 CC=$ANDROID_NDK_HOME/toolchains/llvm/prebuilt/darwin-x86_64/bin/armv7a-linux-androideabi21-clang AR=$ANDROID_NDK_HOME/toolchains/llvm/prebuilt/darwin-x86_64/bin/arm-linux-androideabi-ar HOSTCC=gcc ARM_SOFTFP_ABI=1 -j4
-
+
+- job: OSX_IOS_ARMV8
+  pool:
+    vmImage: 'macOS-10.15'
+  variables:
+    CC: /Applications/Xcode_12.4.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/bin/clang
+    CFLAGS: -O2 -Wno-macro-redefined -isysroot /Applications/Xcode_12.4.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS14.4.sdk -arch arm64 -miphoneos-version-min=10.0
+  steps:
+  - script: |
+      make TARGET=ARMV8 DYNAMIC_ARCH=1 NUM_THREADS=32 HOSTCC=clang NOFORTRAN=1
+
+- job: OSX_IOS_ARMV7
+  pool:
+    vmImage: 'macOS-10.15'
+  variables:
+    CC: /Applications/Xcode_12.4.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/bin/clang
+    CFLAGS: -O2 -mno-thumb -Wno-macro-redefined -isysroot /Applications/Xcode_12.4.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS14.4.sdk -arch armv7 -miphoneos-version-min=5.1
+  steps:
+  - script: |
+      make TARGET=ARMV7 DYNAMIC_ARCH=1 NUM_THREADS=32 HOSTCC=clang NOFORTRAN=1
+
 - job: ALPINE_MUSL
   pool:
     vmImage: 'ubuntu-latest'
   steps:
   - script: |
-      wget 'https://raw.githubusercontent.com/alpinelinux/alpine-chroot-install/v0.9.0/alpine-chroot-install' \
-        && echo 'e5dfbbdc0c4b3363b99334510976c86bfa6cb251  alpine-chroot-install' | sha1sum -c || exit 1
+      wget https://raw.githubusercontent.com/alpinelinux/alpine-chroot-install/v0.13.1/alpine-chroot-install \
+        && echo '7c7e3fa378e69aecc7f5f01bbc759e5f0a9d9b74  alpine-chroot-install' | sha1sum -c \
+        || exit 1
       alpine() { /alpine/enter-chroot -u "$USER" "$@"; }
       sudo sh alpine-chroot-install -p 'build-base gfortran perl linux-headers sudo'
       alpine make DYNAMIC_ARCH=1 BINARY=64
diff --git a/c_check b/c_check
index e24943a29..030f5e632 100644
--- a/c_check
+++ b/c_check
@@ -82,18 +82,19 @@ $os = Interix if ($data =~ /OS_INTERIX/);
 $os = Android if ($data =~ /OS_ANDROID/);
 $os = Haiku   if ($data =~ /OS_HAIKU/);
 
-$architecture = x86     if ($data =~ /ARCH_X86/);
-$architecture = x86_64  if ($data =~ /ARCH_X86_64/);
-$architecture = power   if ($data =~ /ARCH_POWER/);
-$architecture = mips    if ($data =~ /ARCH_MIPS/);
-$architecture = mips64  if ($data =~ /ARCH_MIPS64/);
-$architecture = alpha   if ($data =~ /ARCH_ALPHA/);
-$architecture = sparc   if ($data =~ /ARCH_SPARC/);
-$architecture = ia64    if ($data =~ /ARCH_IA64/);
-$architecture = arm     if ($data =~ /ARCH_ARM/);
-$architecture = arm64   if ($data =~ /ARCH_ARM64/);
-$architecture = zarch   if ($data =~ /ARCH_ZARCH/);
-$architecture = riscv64 if ($data =~ /ARCH_RISCV64/);
+$architecture = x86         if ($data =~ /ARCH_X86/);
+$architecture = x86_64      if ($data =~ /ARCH_X86_64/);
+$architecture = power       if ($data =~ /ARCH_POWER/);
+$architecture = mips        if ($data =~ /ARCH_MIPS/);
+$architecture = mips64      if ($data =~ /ARCH_MIPS64/);
+$architecture = alpha       if ($data =~ /ARCH_ALPHA/);
+$architecture = sparc       if ($data =~ /ARCH_SPARC/);
+$architecture = ia64        if ($data =~ /ARCH_IA64/);
+$architecture = arm         if ($data =~ /ARCH_ARM/);
+$architecture = arm64       if ($data =~ /ARCH_ARM64/);
+$architecture = zarch       if ($data =~ /ARCH_ZARCH/);
+$architecture = riscv64     if ($data =~ /ARCH_RISCV64/);
+$architecture = loongarch64 if ($data =~ /ARCH_LOONGARCH64/);
 
 $defined = 0;
@@ -143,6 +144,11 @@
 if ($architecture eq "riscv64") {
     $defined = 1;
     $binary = 64;
 }
 
+if ($architecture eq "loongarch64") {
+    $defined = 1;
+    $binary = 64;
+}
+
 if ($compiler eq "PGI") {
     $compiler_name .= " -tp p7"    if ($binary eq "32");
     $compiler_name .= " -tp p7-64" if ($binary eq "64");
@@ -215,17 +221,18 @@ if (($architecture eq "mips") || ($architecture eq "mips64")) {
     }
 }
 
-$architecture = x86    if ($data =~ /ARCH_X86/);
-$architecture = x86_64 if ($data =~ /ARCH_X86_64/);
-$architecture = power  if ($data =~ /ARCH_POWER/);
-$architecture = mips   if ($data =~ /ARCH_MIPS/);
-$architecture = mips64 if ($data =~ /ARCH_MIPS64/);
-$architecture = alpha  if ($data =~ /ARCH_ALPHA/);
-$architecture = sparc  if ($data =~ /ARCH_SPARC/);
-$architecture = ia64   if ($data =~ /ARCH_IA64/);
-$architecture = arm    if ($data =~ /ARCH_ARM/);
-$architecture = arm64  if ($data =~ /ARCH_ARM64/);
-$architecture = zarch  if ($data =~ /ARCH_ZARCH/);
+$architecture = x86         if ($data =~ /ARCH_X86/);
+$architecture = x86_64      if ($data =~ /ARCH_X86_64/);
+$architecture = power       if ($data =~ /ARCH_POWER/);
+$architecture = mips        if ($data =~ /ARCH_MIPS/);
+$architecture = mips64      if ($data =~ /ARCH_MIPS64/);
+$architecture = alpha       if ($data =~ /ARCH_ALPHA/);
+$architecture = sparc       if ($data =~ /ARCH_SPARC/);
+$architecture = ia64        if ($data =~ /ARCH_IA64/);
+$architecture = arm         if ($data =~ /ARCH_ARM/);
+$architecture = arm64       if ($data =~ /ARCH_ARM64/);
+$architecture = zarch       if ($data =~ /ARCH_ZARCH/);
+$architecture = loongarch64 if ($data =~ /ARCH_LOONGARCH64/);
 
 $binformat = bin32;
 $binformat = bin64 if ($data =~ /BINARY_64/);
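Editorial note on the `c_check` hunks above: the script assigns `$architecture` from `ARCH_*` tokens that survive preprocessing of a small probe source (`ctest.c`). A minimal sketch of the matching probe guard — the `__loongarch64` macro name is an assumption here, not taken from this diff:

```c
/* ctest.c-style probe sketch: when the compiler targets LoongArch64,
 * the bare ARCH_LOONGARCH64 token survives preprocessing and c_check
 * greps it out of the output. __loongarch64 is an assumed predefine. */
#if defined(__loongarch64)
ARCH_LOONGARCH64
#endif
```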
diff --git a/cblas.h b/cblas.h
index f0220eb99..a5ad25ad7 100644
--- a/cblas.h
+++ b/cblas.h
@@ -400,6 +400,8 @@ void cblas_dbf16tod(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPE
 float cblas_sbdot(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST bfloat16 *y, OPENBLAS_CONST blasint incy);
 void cblas_sbgemv(OPENBLAS_CONST enum CBLAS_ORDER order, OPENBLAS_CONST enum CBLAS_TRANSPOSE trans, OPENBLAS_CONST blasint m, OPENBLAS_CONST blasint n, OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *a, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST float beta, float *y, OPENBLAS_CONST blasint incy);
+void cblas_sbgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K,
+		  OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc);
 #ifdef __cplusplus
 }
 #endif  /* __cplusplus */
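The new `cblas_sbgemm` computes C = alpha*op(A)*op(B) + beta*C with bfloat16 inputs and float output. A hedged usage sketch, assuming the `cblas_sbstobf16` float-to-bfloat16 conversion helper that cblas.h declares alongside the other bfloat16 routines (not shown in this hunk), and a library built with BUILD_BFLOAT16:

```c
#include <cblas.h>

/* Minimal cblas_sbgemm usage sketch (assumes BUILD_BFLOAT16 support). */
void sbgemm_demo(void) {
    enum { M = 4, N = 4, K = 4 };
    float a32[M * K], b32[K * N], c[M * N];
    bfloat16 a16[M * K], b16[K * N];

    for (int i = 0; i < M * K; i++) a32[i] = 1.0f;
    for (int i = 0; i < K * N; i++) b32[i] = 2.0f;
    for (int i = 0; i < M * N; i++) c[i] = 0.0f;

    /* convert float inputs to bfloat16 (assumed helper, see lead-in) */
    cblas_sbstobf16(M * K, a32, 1, a16, 1);
    cblas_sbstobf16(K * N, b32, 1, b16, 1);

    /* C = 1.0 * A * B + 0.0 * C, accumulating in float */
    cblas_sbgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, M, N, K,
                 1.0f, a16, M, b16, K, 0.0f, c, M);
}
```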
diff --git a/cmake/arch.cmake b/cmake/arch.cmake
index 154e59db6..57ee5a4fb 100644
--- a/cmake/arch.cmake
+++ b/cmake/arch.cmake
@@ -113,6 +113,10 @@ if (MIPS64)
   set(NO_BINARY_MODE 1)
 endif ()
 
+if (LOONGARCH64)
+  set(NO_BINARY_MODE 1)
+endif ()
+
 if (${ARCH} STREQUAL "alpha")
   set(NO_BINARY_MODE 1)
   set(BINARY_DEFINED 1)
diff --git a/cmake/cc.cmake b/cmake/cc.cmake
index 76952152b..1794b5e5b 100644
--- a/cmake/cc.cmake
+++ b/cmake/cc.cmake
@@ -29,6 +29,15 @@ if (${CMAKE_C_COMPILER_ID} STREQUAL "GNU" OR ${CMAKE_C_COMPILER_ID} STREQUAL "LS
     set(FCOMMON_OPT "${FCOMMON_OPT} -march=mips64")
   endif ()
 
+  if (LOONGARCH64)
+    if (BINARY64)
+      set(CCOMMON_OPT "${CCOMMON_OPT} -mabi=lp64")
+    else ()
+      set(CCOMMON_OPT "${CCOMMON_OPT} -mabi=lp32")
+    endif ()
+    set(BINARY_DEFINED 1)
+  endif ()
+
   if (CMAKE_SYSTEM_NAME STREQUAL "AIX")
     set(BINARY_DEFINED 1)
   endif ()
@@ -124,9 +133,9 @@ if (NOT DYNAMIC_ARCH)
   if (HAVE_AVX)
     set (CCOMMON_OPT "${CCOMMON_OPT} -mavx")
   endif ()
-  if (HAVE_FMA3)
-    set (CCOMMON_OPT "${CCOMMON_OPT} -mfma")
-  endif ()
+  # if (HAVE_FMA3)
+  #set (CCOMMON_OPT "${CCOMMON_OPT} -mfma")
+  #endif ()
   if (HAVE_SSE)
     set (CCOMMON_OPT "${CCOMMON_OPT} -msse")
   endif ()
diff --git a/cmake/fc.cmake b/cmake/fc.cmake
index fc1f9bb22..f7aa4c5c9 100644
--- a/cmake/fc.cmake
+++ b/cmake/fc.cmake
@@ -61,6 +61,13 @@ if (${F_COMPILER} STREQUAL "GFORTRAN")
       set(FCOMMON_OPT "${FCOMMON_OPT} -mabi=n32")
     endif ()
   endif ()
+  if (LOONGARCH64)
+    if (BINARY64)
+      set(FCOMMON_OPT "${FCOMMON_OPT} -mabi=lp64")
+    else ()
+      set(FCOMMON_OPT "${FCOMMON_OPT} -mabi=lp32")
+    endif ()
+  endif ()
 else ()
   if (BINARY64)
     set(FCOMMON_OPT "${FCOMMON_OPT} -m64")
@@ -97,7 +104,7 @@ endif ()
 
 if (${F_COMPILER} STREQUAL "IBM")
   set(CCOMMON_OPT "${CCOMMON_OPT} -DF_INTERFACE_IBM")
-  # FCOMMON_OPT += -qarch=440
+  set(FCOMMON_OPT "${FCOMMON_OPT} -qrecur")
   if (BINARY64)
     set(FCOMMON_OPT "${FCOMMON_OPT} -q64")
     if (INTERFACE64)
diff --git a/cmake/kernel.cmake b/cmake/kernel.cmake
index 0c102bae5..09ca5eb57 100644
--- a/cmake/kernel.cmake
+++ b/cmake/kernel.cmake
@@ -134,6 +134,8 @@ if (BUILD_BFLOAT16)
     set(SHSWAPKERNEL ../arm/swap.c)
     set(TOBF16KERNEL ../x86_64/tobf16.c)
     set(BF16TOKERNEL ../x86_64/bf16to.c)
+    set(SBGEMVNKERNEL ../x86_64/sbgemv_n.c)
+    set(SBGEMVTKERNEL ../x86_64/sbgemv_t.c)
   endif ()
 endmacro ()
diff --git a/cmake/system.cmake b/cmake/system.cmake
index 34874827c..f56ded966 100644
--- a/cmake/system.cmake
+++ b/cmake/system.cmake
@@ -186,11 +186,11 @@ if (DEFINED TARGET)
       set (KERNEL_DEFINITIONS "${KERNEL_DEFINITIONS} -mavx2")
     endif()
   endif()
-  if (DEFINED HAVE_FMA3)
-    if (NOT NO_AVX2)
-      set (KERNEL_DEFINITIONS "${KERNEL_DEFINITIONS} -mfma")
-    endif()
-  endif()
+  # if (DEFINED HAVE_FMA3)
+  #   if (NOT NO_AVX2)
+  #     set (KERNEL_DEFINITIONS "${KERNEL_DEFINITIONS} -mfma")
+  #   endif()
+  # endif()
   if (DEFINED HAVE_SSE)
     set (KERNEL_DEFINITIONS "${KERNEL_DEFINITIONS} -msse")
   endif()
@@ -258,6 +258,13 @@ if (NEED_PIC)
   endif()
 endif ()
 
+if (X86_64)
+  set(SMALL_MATRIX_OPT TRUE)
+endif ()
+if (SMALL_MATRIX_OPT)
+  set(CCOMMON_OPT "${CCOMMON_OPT} -DSMALL_MATRIX_OPT")
+endif ()
+
 if (DYNAMIC_ARCH)
   if (X86 OR X86_64 OR ARM64 OR PPC)
     set(CCOMMON_OPT "${CCOMMON_OPT} -DDYNAMIC_ARCH")
@@ -462,6 +469,9 @@ endif()
 if (BUILD_COMPLEX16)
   set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DBUILD_COMPLEX16")
 endif()
+if (BUILD_BFLOAT16)
+  set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DBUILD_BFLOAT16")
+endif()
 if(NOT MSVC)
   set(CMAKE_ASM_FLAGS "${CMAKE_ASM_FLAGS} ${CCOMMON_OPT}")
 endif()
diff --git a/cmake/system_check.cmake b/cmake/system_check.cmake
index fdc79c8ce..8d0558c0e 100644
--- a/cmake/system_check.cmake
+++ b/cmake/system_check.cmake
@@ -38,6 +38,8 @@ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "ppc.*|power.*|Power.*")
   set(PPC 1)
 elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "mips64.*")
   set(MIPS64 1)
+elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "loongarch64.*")
+  set(LOONGARCH64 1)
 elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "amd64.*|x86_64.*|AMD64.*")
   if (NOT BINARY)
     if("${CMAKE_SIZEOF_VOID_P}" EQUAL "8")
@@ -95,7 +97,7 @@ else()
 endif ()
 
 if (NOT BINARY)
-  if (X86_64 OR ARM64 OR PPC OR MIPS64)
+  if (X86_64 OR ARM64 OR PPC OR MIPS64 OR LOONGARCH64)
     set(BINARY 64)
   else ()
     set(BINARY 32)
diff --git a/cmake/utils.cmake b/cmake/utils.cmake
index 794d73d06..01b489f2a 100644
--- a/cmake/utils.cmake
+++ b/cmake/utils.cmake
@@ -157,31 +157,31 @@ endfunction ()
 # STRING - compiles only the given type (e.g. DOUBLE)
 function(GenerateNamedObjects sources_in)
 
-  if (DEFINED ARGV1)
+  if (${ARGC} GREATER 1)
     set(defines_in ${ARGV1})
   endif ()
 
-  if (DEFINED ARGV2 AND NOT "${ARGV2}" STREQUAL "")
+  if (${ARGC} GREATER 2 AND NOT "${ARGV2}" STREQUAL "")
     set(name_in ${ARGV2})
     # strip off extension for kernel files that pass in the object name.
     get_filename_component(name_in ${name_in} NAME_WE)
   endif ()
 
-  if (DEFINED ARGV3)
+  if (${ARGC} GREATER 3)
     set(use_cblas ${ARGV3})
   else ()
     set(use_cblas false)
   endif ()
 
-  if (DEFINED ARGV4)
+  if (${ARGC} GREATER 4)
     set(replace_last_with ${ARGV4})
   endif ()
 
-  if (DEFINED ARGV5)
+  if (${ARGC} GREATER 5)
     set(append_with ${ARGV5})
   endif ()
 
-  if (DEFINED ARGV6)
+  if (${ARGC} GREATER 6)
     set(no_float_type ${ARGV6})
   else ()
     set(no_float_type false)
@@ -196,7 +196,7 @@ function(GenerateNamedObjects sources_in)
   set(real_only false)
   set(complex_only false)
   set(mangle_complex_sources false)
-  if (DEFINED ARGV7 AND NOT "${ARGV7}" STREQUAL "")
+  if (${ARGC} GREATER 7 AND NOT "${ARGV7}" STREQUAL "")
     if (${ARGV7} EQUAL 1)
       set(real_only true)
     elseif (${ARGV7} EQUAL 2)
@@ -311,7 +311,15 @@ function(GenerateNamedObjects sources_in)
       configure_file(${new_source_file}.tmp ${new_source_file} COPYONLY)
       file(REMOVE ${new_source_file}.tmp)
       list(APPEND SRC_LIST_OUT ${new_source_file})
-
+      message (STATUS ${new_source_file})
+      if (DEFINED HAVE_FMA3)
+        if ( ${new_source_file} MATCHES "(s|d?)rot_k.*c")
+          set_source_files_properties(${new_source_file} PROPERTIES COMPILE_OPTIONS "-mfma")
+        endif ()
+        if ( ${new_source_file} MATCHES "dgemv_t_k.*c")
+          set_source_files_properties(${new_source_file} PROPERTIES COMPILE_OPTIONS "-mfma")
+        endif ()
+      endif ()
     endforeach ()
   endforeach ()
 
@@ -334,17 +342,17 @@ endfunction ()
 function(GenerateCombinationObjects sources_in defines_in absent_codes_in all_defines_in replace_scheme)
 
   set(alternate_name_in "")
-  if (DEFINED ARGV5)
+  if (${ARGC} GREATER 5)
     set(alternate_name_in ${ARGV5})
   endif ()
 
   set(no_float_type false)
-  if (DEFINED ARGV6)
+  if (${ARGC} GREATER 6)
     set(no_float_type ${ARGV6})
   endif ()
 
   set(complex_filename_scheme "")
-  if (DEFINED ARGV7)
+  if (${ARGC} GREATER 7)
     set(complex_filename_scheme ${ARGV7})
   endif ()
diff --git a/common.h b/common.h
index ac795937c..ff5254a5c 100644
--- a/common.h
+++ b/common.h
@@ -449,7 +449,7 @@ please https://github.com/xianyi/OpenBLAS/issues/246
 #include "common_mips.h"
 #endif
 
-
+ 
 #ifdef ARCH_RISCV64
 #include "common_riscv64.h"
 #endif
@@ -470,6 +470,10 @@ please https://github.com/xianyi/OpenBLAS/issues/246
 #include "common_zarch.h"
 #endif
 
+#ifdef ARCH_LOONGARCH64
+#include "common_loongarch64.h"
+#endif
+
 #ifndef ASSEMBLER
 #ifdef OS_WINDOWSSTORE
 typedef char env_var_t[MAX_PATH];
diff --git a/common_arm64.h b/common_arm64.h
index 2270ffba7..029e23886 100644
--- a/common_arm64.h
+++ b/common_arm64.h
@@ -120,7 +120,7 @@ static inline int blas_quickdivide(blasint x, blasint y){
 	.text ;
 	.p2align 2 ;
 	.global REALNAME ;
-#ifndef __APPLE__
+#if !defined(__APPLE__) && !defined(_WIN32)
 	.type REALNAME, %function ;
 #endif
 REALNAME:
diff --git a/common_c.h b/common_c.h
index 40ecf5b8b..6cff610bb 100644
--- a/common_c.h
+++ b/common_c.h
@@ -232,6 +232,8 @@
 
 #define CGEADD_K                cgeadd_k
 
+#define CGEMM_SMALL_MATRIX_PERMIT	cgemm_small_matrix_permit
+
 #else
 
 #define CAMAX_K			gotoblas -> camax_k
@@ -426,8 +428,51 @@
 
 #define CGEADD_K                gotoblas -> cgeadd_k
 
+#define CGEMM_SMALL_MATRIX_PERMIT	gotoblas -> cgemm_small_matrix_permit
+
 #endif
 
+#define CGEMM_SMALL_KERNEL_NN	FUNC_OFFSET(cgemm_small_kernel_nn)
+#define CGEMM_SMALL_KERNEL_NT	FUNC_OFFSET(cgemm_small_kernel_nt)
+#define CGEMM_SMALL_KERNEL_NR	FUNC_OFFSET(cgemm_small_kernel_nr)
+#define CGEMM_SMALL_KERNEL_NC	FUNC_OFFSET(cgemm_small_kernel_nc)
+
+#define CGEMM_SMALL_KERNEL_TN	FUNC_OFFSET(cgemm_small_kernel_tn)
+#define CGEMM_SMALL_KERNEL_TT	FUNC_OFFSET(cgemm_small_kernel_tt)
+#define CGEMM_SMALL_KERNEL_TR	FUNC_OFFSET(cgemm_small_kernel_tr)
+#define CGEMM_SMALL_KERNEL_TC	FUNC_OFFSET(cgemm_small_kernel_tc)
+
+#define CGEMM_SMALL_KERNEL_RN	FUNC_OFFSET(cgemm_small_kernel_rn)
+#define CGEMM_SMALL_KERNEL_RT	FUNC_OFFSET(cgemm_small_kernel_rt)
+#define CGEMM_SMALL_KERNEL_RR	FUNC_OFFSET(cgemm_small_kernel_rr)
+#define CGEMM_SMALL_KERNEL_RC	FUNC_OFFSET(cgemm_small_kernel_rc)
+
+#define CGEMM_SMALL_KERNEL_CN	FUNC_OFFSET(cgemm_small_kernel_cn)
+#define CGEMM_SMALL_KERNEL_CT	FUNC_OFFSET(cgemm_small_kernel_ct)
+#define CGEMM_SMALL_KERNEL_CR	FUNC_OFFSET(cgemm_small_kernel_cr)
+#define CGEMM_SMALL_KERNEL_CC	FUNC_OFFSET(cgemm_small_kernel_cc)
+
+#define CGEMM_SMALL_KERNEL_B0_NN	FUNC_OFFSET(cgemm_small_kernel_b0_nn)
+#define CGEMM_SMALL_KERNEL_B0_NT	FUNC_OFFSET(cgemm_small_kernel_b0_nt)
+#define CGEMM_SMALL_KERNEL_B0_NR	FUNC_OFFSET(cgemm_small_kernel_b0_nr)
+#define CGEMM_SMALL_KERNEL_B0_NC	FUNC_OFFSET(cgemm_small_kernel_b0_nc)
+
+#define CGEMM_SMALL_KERNEL_B0_TN	FUNC_OFFSET(cgemm_small_kernel_b0_tn)
+#define CGEMM_SMALL_KERNEL_B0_TT	FUNC_OFFSET(cgemm_small_kernel_b0_tt)
+#define CGEMM_SMALL_KERNEL_B0_TR	FUNC_OFFSET(cgemm_small_kernel_b0_tr)
+#define CGEMM_SMALL_KERNEL_B0_TC	FUNC_OFFSET(cgemm_small_kernel_b0_tc)
+
+#define CGEMM_SMALL_KERNEL_B0_RN	FUNC_OFFSET(cgemm_small_kernel_b0_rn)
+#define CGEMM_SMALL_KERNEL_B0_RT	FUNC_OFFSET(cgemm_small_kernel_b0_rt)
+#define CGEMM_SMALL_KERNEL_B0_RR	FUNC_OFFSET(cgemm_small_kernel_b0_rr)
+#define CGEMM_SMALL_KERNEL_B0_RC	FUNC_OFFSET(cgemm_small_kernel_b0_rc)
+
+#define CGEMM_SMALL_KERNEL_B0_CN	FUNC_OFFSET(cgemm_small_kernel_b0_cn)
+#define CGEMM_SMALL_KERNEL_B0_CT	FUNC_OFFSET(cgemm_small_kernel_b0_ct)
+#define CGEMM_SMALL_KERNEL_B0_CR	FUNC_OFFSET(cgemm_small_kernel_b0_cr)
+#define CGEMM_SMALL_KERNEL_B0_CC	FUNC_OFFSET(cgemm_small_kernel_b0_cc)
+
+
 #define CGEMM_NN		cgemm_nn
 #define CGEMM_CN		cgemm_cn
 #define CGEMM_TN		cgemm_tn
diff --git a/common_d.h b/common_d.h
index 94dc3eea8..6f4bb2ded 100644
--- a/common_d.h
+++ b/common_d.h
@@ -157,6 +157,8 @@
 #define DIMATCOPY_K_RT		dimatcopy_k_rt
 #define DGEADD_K                dgeadd_k
 
+#define DGEMM_SMALL_MATRIX_PERMIT	dgemm_small_matrix_permit
+
 #else
 
 #define DAMAX_K			gotoblas -> damax_k
@@ -281,8 +283,21 @@
 
 #define DGEADD_K                gotoblas -> dgeadd_k
 
+#define DGEMM_SMALL_MATRIX_PERMIT	gotoblas -> dgemm_small_matrix_permit
+
 #endif
 
+#define DGEMM_SMALL_KERNEL_NN	FUNC_OFFSET(dgemm_small_kernel_nn)
+#define DGEMM_SMALL_KERNEL_NT	FUNC_OFFSET(dgemm_small_kernel_nt)
+#define DGEMM_SMALL_KERNEL_TN	FUNC_OFFSET(dgemm_small_kernel_tn)
+#define DGEMM_SMALL_KERNEL_TT	FUNC_OFFSET(dgemm_small_kernel_tt)
+
+#define DGEMM_SMALL_KERNEL_B0_NN	FUNC_OFFSET(dgemm_small_kernel_b0_nn)
+#define DGEMM_SMALL_KERNEL_B0_NT	FUNC_OFFSET(dgemm_small_kernel_b0_nt)
+#define DGEMM_SMALL_KERNEL_B0_TN	FUNC_OFFSET(dgemm_small_kernel_b0_tn)
+#define DGEMM_SMALL_KERNEL_B0_TT	FUNC_OFFSET(dgemm_small_kernel_b0_tt)
+
+
 #define DGEMM_NN		dgemm_nn
 #define DGEMM_CN		dgemm_tn
 #define DGEMM_TN		dgemm_tn
diff --git a/common_level3.h b/common_level3.h
index c4f9435a9..5080ada10 100644
--- a/common_level3.h
+++ b/common_level3.h
@@ -515,6 +515,129 @@ int qgemm_kernel(BLASLONG, BLASLONG, BLASLONG, xidouble *, xidouble *, xidouble
 int qgemm_kernel(BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble *, xdouble *, xdouble *, BLASLONG);
 #endif
 
+#ifdef SMALL_MATRIX_OPT
+int sbgemm_small_matrix_permit(int transa, int transb, BLASLONG m, BLASLONG n, BLASLONG k, float alpha, float beta);
+
+int sbgemm_small_kernel_nn(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc);
+int sbgemm_small_kernel_nt(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc);
+int sbgemm_small_kernel_tn(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc);
+int sbgemm_small_kernel_tt(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc);
+
+int sgemm_small_matrix_permit(int transa, int transb, BLASLONG m, BLASLONG n, BLASLONG k, float alpha, float beta);
+
+int sgemm_small_kernel_nn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc);
+int sgemm_small_kernel_nt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc);
+int sgemm_small_kernel_tn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc);
+int sgemm_small_kernel_tt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc);
+
+int dgemm_small_matrix_permit(int transa, int transb, BLASLONG m, BLASLONG n, BLASLONG k, double alpha, double beta);
+
+int dgemm_small_kernel_nn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc);
+int dgemm_small_kernel_nt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc);
+int dgemm_small_kernel_tn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc);
+int dgemm_small_kernel_tt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc);
+
+int sbgemm_small_kernel_b0_nn(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float * C, BLASLONG ldc);
+int sbgemm_small_kernel_b0_nt(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float * C, BLASLONG ldc);
+int sbgemm_small_kernel_b0_tn(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float * C, BLASLONG ldc);
+int sbgemm_small_kernel_b0_tt(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float * C, BLASLONG ldc);
+
+int sgemm_small_kernel_b0_nn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc);
+int sgemm_small_kernel_b0_nt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc);
+int sgemm_small_kernel_b0_tn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc);
+int sgemm_small_kernel_b0_tt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc);
+
+int dgemm_small_kernel_b0_nn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc);
+int dgemm_small_kernel_b0_nt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc);
+int dgemm_small_kernel_b0_tn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc);
+int dgemm_small_kernel_b0_tt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc);
+
+int cgemm_small_matrix_permit(int transa, int transb, BLASLONG m, BLASLONG n, BLASLONG k, float alpha0, float alpha1, float beta0, float beta1);
+
+int cgemm_small_kernel_nn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc);
+int cgemm_small_kernel_nt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc);
+int cgemm_small_kernel_nr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc);
+int cgemm_small_kernel_nc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc);
+
+int cgemm_small_kernel_tn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc);
+int cgemm_small_kernel_tt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc);
+int cgemm_small_kernel_tr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc);
+int cgemm_small_kernel_tc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc);
+
+int cgemm_small_kernel_rn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc);
+int cgemm_small_kernel_rt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc);
+int cgemm_small_kernel_rr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc);
+int cgemm_small_kernel_rc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc);
+
+int cgemm_small_kernel_cn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc);
+int cgemm_small_kernel_ct(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc);
+int cgemm_small_kernel_cr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc);
+int cgemm_small_kernel_cc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc);
+
+int zgemm_small_matrix_permit(int transa, int transb, BLASLONG m, BLASLONG n, BLASLONG k, double alpha0, double alpha1, double beta0, double beta1);
+
+int zgemm_small_kernel_nn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc);
+int zgemm_small_kernel_nt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc);
+int zgemm_small_kernel_nr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc);
+int zgemm_small_kernel_nc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc);
+
+int zgemm_small_kernel_tn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc);
+int zgemm_small_kernel_tt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc);
+int zgemm_small_kernel_tr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc);
+int zgemm_small_kernel_tc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc);
+
+int zgemm_small_kernel_rn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc);
+int zgemm_small_kernel_rt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc);
+int zgemm_small_kernel_rr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc);
+int zgemm_small_kernel_rc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc);
+
+int zgemm_small_kernel_cn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc);
+int zgemm_small_kernel_ct(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc);
+int zgemm_small_kernel_cr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc);
+int zgemm_small_kernel_cc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc);
+
+int cgemm_small_kernel_b0_nn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc);
+int cgemm_small_kernel_b0_nt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc);
+int cgemm_small_kernel_b0_nr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc);
+int cgemm_small_kernel_b0_nc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc);
+
+int cgemm_small_kernel_b0_tn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc);
+int cgemm_small_kernel_b0_tt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc);
+int cgemm_small_kernel_b0_tr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc);
+int cgemm_small_kernel_b0_tc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc);
+
+int cgemm_small_kernel_b0_rn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc);
+int cgemm_small_kernel_b0_rt(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc);
+int cgemm_small_kernel_b0_rr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc);
+int cgemm_small_kernel_b0_rc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc);
+
+int cgemm_small_kernel_b0_cn(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc);
+int cgemm_small_kernel_b0_ct(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc);
+int cgemm_small_kernel_b0_cr(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc);
+int cgemm_small_kernel_b0_cc(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc);
+
+int zgemm_small_kernel_b0_nn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc);
+int zgemm_small_kernel_b0_nt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc);
+int zgemm_small_kernel_b0_nr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc);
+int zgemm_small_kernel_b0_nc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc);
+
+int zgemm_small_kernel_b0_tn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc);
+int zgemm_small_kernel_b0_tt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc);
+int zgemm_small_kernel_b0_tr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc);
+int zgemm_small_kernel_b0_tc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc);
+
+int zgemm_small_kernel_b0_rn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc);
+int zgemm_small_kernel_b0_rt(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc);
+int zgemm_small_kernel_b0_rr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc);
+int zgemm_small_kernel_b0_rc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc);
+
+int zgemm_small_kernel_b0_cn(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc);
+int zgemm_small_kernel_b0_ct(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc);
+int zgemm_small_kernel_b0_cr(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc);
+int zgemm_small_kernel_b0_cc(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc);
+
+#endif
+
 int cgemm_kernel_n(BLASLONG, BLASLONG, BLASLONG, float, float, float *, float *, float *, BLASLONG);
 int cgemm_kernel_l(BLASLONG, BLASLONG, BLASLONG, float, float, float *, float *, float *, BLASLONG);
 int cgemm_kernel_r(BLASLONG, BLASLONG, BLASLONG, float, float, float *, float *, float *, BLASLONG);
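The `*_small_matrix_permit` functions declared above gate the new SMALL_MATRIX_OPT fast path. A hypothetical sketch of such a permit function — the threshold mirrors the Changelog note that the SkylakeX small-matrix kernels target M*N*K<=1000000; the real implementations live in the per-architecture kernel directories, not in this header:

```c
/* Hypothetical permit heuristic, not the actual kernel source:
 * return nonzero to route this call to the small-matrix kernels,
 * zero to fall through to the regular blocked GEMM path. */
int dgemm_small_matrix_permit(int transa, int transb, BLASLONG m, BLASLONG n,
                              BLASLONG k, double alpha, double beta)
{
    double mnk = (double)m * (double)n * (double)k;
    return (mnk <= 1000000.0);
}
```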
Neither the name of the OpenBLAS project nor the names of + its contributors may be used to endorse or promote products + derived from this software without specific prior written + permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +**********************************************************************************/ + +/*********************************************************************/ +/* Copyright 2009, 2010 The University of Texas at Austin. */ +/* All rights reserved. */ +/* */ +/* Redistribution and use in source and binary forms, with or */ +/* without modification, are permitted provided that the following */ +/* conditions are met: */ +/* */ +/* 1. Redistributions of source code must retain the above */ +/* copyright notice, this list of conditions and the following */ +/* disclaimer. */ +/* */ +/* 2. Redistributions in binary form must reproduce the above */ +/* copyright notice, this list of conditions and the following */ +/* disclaimer in the documentation and/or other materials */ +/* provided with the distribution. */ +/* */ +/* THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF TEXAS AT */ +/* AUSTIN ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, */ +/* INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF */ +/* MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE */ +/* DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY OF TEXAS AT */ +/* AUSTIN OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, */ +/* INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES */ +/* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE */ +/* GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR */ +/* BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF */ +/* LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT */ +/* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT */ +/* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE */ +/* POSSIBILITY OF SUCH DAMAGE. */ +/* */ +/* The views and conclusions contained in the software and */ +/* documentation are those of the authors and should not be */ +/* interpreted as representing official policies, either expressed */ +/* or implied, of The University of Texas at Austin. 
*/ +/*********************************************************************/ + +#ifndef COMMON_LOONGARCH64 +#define COMMON_LOONGARCH64 + +#define MB __sync_synchronize() +#define WMB __sync_synchronize() +#define RMB __sync_synchronize() + +#define INLINE inline + +#ifndef ASSEMBLER + +static inline int blas_quickdivide(blasint x, blasint y){ + return x / y; +} + +#ifdef DOUBLE +#define GET_IMAGE(res) __asm__ __volatile__("fmov.d %0, $f2" : "=f"(res) : : "memory") +#else +#define GET_IMAGE(res) __asm__ __volatile__("fmov.s %0, $f2" : "=f"(res) : : "memory") +#endif + +#define GET_IMAGE_CANCEL + +#else + +#ifdef DOUBLE +#define LD fld.d +#define ST fst.d +#define MADD fmadd.d +#define NMADD fnmadd.d +#define MSUB fmsub.d +#define NMSUB fnmsub.d +#define ADD fadd.d +#define SUB fsub.d +#define MUL fmul.d +#define MOV fmov.d +#define CMOVT fsel +#define MTC movgr2fr.d +#define FABS fabs.d +#define CMPEQ fcmp.ceq.d +#define CMPLE fcmp.cle.d +#define CMPLT fcmp.clt.d +#define NEG fneg.d +#else +#define LD fld.s +#define ST fst.s +#define MADD fmadd.s +#define NMADD fnmadd.s +#define MSUB fmsub.s +#define NMSUB fnmsub.s +#define ADD fadd.s +#define SUB fsub.s +#define MUL fmul.s +#define MOV fmov.s +#define CMOVT fsel +#define MTC movgr2fr.w +#define FABS fabs.s +#define CMPEQ fcmp.ceq.s +#define CMPLE fcmp.cle.s +#define CMPLT fcmp.clt.s +#define NEG fneg.s +#endif /* defined(DOUBLE) */ + +#if defined(__64BIT__) && defined(USE64BITINT) +#define LDINT ld.d +#define LDARG ld.d +#define SDARG st.d +#elif defined(__64BIT__) && !defined(USE64BITINT) +#define LDINT ld.w +#define LDARG ld.d +#define SDARG st.d +#else +#define LDINT ld.w +#define LDARG ld.w +#define SDARG st.w +#endif + + +#ifndef F_INTERFACE +#define REALNAME ASMNAME +#else +#define REALNAME ASMFNAME +#endif /* defined(F_INTERFACE) */ + +#if defined(ASSEMBLER) && !defined(NEEDPARAM) + +#define PROLOGUE \ + .text ;\ + .align 5 ;\ + .globl REALNAME ;\ + .type REALNAME, @function ;\ +REALNAME: ;\ + +#if defined(__linux__) && defined(__ELF__) +#define GNUSTACK .section .note.GNU-stack,"",@progbits +#else +#define GNUSTACK +#endif /* defined(__linux__) && defined(__ELF__) */ + +#define EPILOGUE \ + .end REALNAME ;\ + GNUSTACK + +#define PROFCODE + +#define MOVT(dst, src, cc) \ + bceqz cc, 1f; \ + add.d dst, src, $r0; \ + 1: + +#endif /* defined(ASSEMBLER) && !defined(NEEDPARAM) */ + +#endif /* defined(ASSEMBLER) */ + +#define SEEK_ADDRESS + +#define BUFFER_SIZE ( 32 << 20) + +#define PAGESIZE (16UL << 10) +#define FIXED_PAGESIZE (16UL << 10) +#define HUGE_PAGESIZE ( 2 << 20) + +#define BASE_ADDRESS (START_ADDRESS - BUFFER_SIZE * MAX_CPU_NUMBER) + +#ifndef MAP_ANONYMOUS +#define MAP_ANONYMOUS MAP_ANON +#endif + +#endif diff --git a/common_macro.h b/common_macro.h index c6ea1bfd9..cf2a3fd88 100644 --- a/common_macro.h +++ b/common_macro.h @@ -644,6 +644,17 @@ #define GEADD_K DGEADD_K +#define GEMM_SMALL_MATRIX_PERMIT DGEMM_SMALL_MATRIX_PERMIT + +#define GEMM_SMALL_KERNEL_NN DGEMM_SMALL_KERNEL_NN +#define GEMM_SMALL_KERNEL_NT DGEMM_SMALL_KERNEL_NT +#define GEMM_SMALL_KERNEL_TN DGEMM_SMALL_KERNEL_TN +#define GEMM_SMALL_KERNEL_TT DGEMM_SMALL_KERNEL_TT +#define GEMM_SMALL_KERNEL_B0_NN DGEMM_SMALL_KERNEL_B0_NN +#define GEMM_SMALL_KERNEL_B0_NT DGEMM_SMALL_KERNEL_B0_NT +#define GEMM_SMALL_KERNEL_B0_TN DGEMM_SMALL_KERNEL_B0_TN +#define GEMM_SMALL_KERNEL_B0_TT DGEMM_SMALL_KERNEL_B0_TT + #elif defined(BFLOAT16) #define D_TO_BF16_K SBDTOBF16_K @@ -931,6 +942,18 @@ #define GEADD_K SGEADD_K +#define GEMM_SMALL_MATRIX_PERMIT SBGEMM_SMALL_MATRIX_PERMIT + 
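The common_macro.h hunks here extend OpenBLAS's precision-mapping layer: the type-generic level-3 driver sources see a single set of GEMM_SMALL_KERNEL_* names, and this header resolves them per translation unit to the DGEMM/SBGEMM/SGEMM/CGEMM/ZGEMM implementations. A minimal sketch of the pattern, with simplified conditions standing in for the header's real ones:

/* Sketch only (simplified conditions): one generic name resolves to a
 * per-precision kernel at compile time, so the driver is written once. */
#if defined(DOUBLE)
#define GEMM_SMALL_KERNEL_NN dgemm_small_kernel_nn
#elif defined(BFLOAT16)
#define GEMM_SMALL_KERNEL_NN sbgemm_small_kernel_nn
#else
#define GEMM_SMALL_KERNEL_NN sgemm_small_kernel_nn
#endif
/* A generic driver can then call
 *   GEMM_SMALL_KERNEL_NN(m, n, k, a, lda, alpha, b, ldb, beta, c, ldc);
 * regardless of the element type being compiled. */

The complex precisions additionally carry R and C variants alongside N and T, since conjugating either operand selects a different kernel.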
+#define GEMM_SMALL_KERNEL_NN SBGEMM_SMALL_KERNEL_NN +#define GEMM_SMALL_KERNEL_NT SBGEMM_SMALL_KERNEL_NT +#define GEMM_SMALL_KERNEL_TN SBGEMM_SMALL_KERNEL_TN +#define GEMM_SMALL_KERNEL_TT SBGEMM_SMALL_KERNEL_TT + +#define GEMM_SMALL_KERNEL_B0_NN SBGEMM_SMALL_KERNEL_B0_NN +#define GEMM_SMALL_KERNEL_B0_NT SBGEMM_SMALL_KERNEL_B0_NT +#define GEMM_SMALL_KERNEL_B0_TN SBGEMM_SMALL_KERNEL_B0_TN +#define GEMM_SMALL_KERNEL_B0_TT SBGEMM_SMALL_KERNEL_B0_TT + #endif #else @@ -1236,6 +1259,19 @@ #define IMATCOPY_K_RT SIMATCOPY_K_RT #define GEADD_K SGEADD_K + +#define GEMM_SMALL_MATRIX_PERMIT SGEMM_SMALL_MATRIX_PERMIT + +#define GEMM_SMALL_KERNEL_NN SGEMM_SMALL_KERNEL_NN +#define GEMM_SMALL_KERNEL_NT SGEMM_SMALL_KERNEL_NT +#define GEMM_SMALL_KERNEL_TN SGEMM_SMALL_KERNEL_TN +#define GEMM_SMALL_KERNEL_TT SGEMM_SMALL_KERNEL_TT + +#define GEMM_SMALL_KERNEL_B0_NN SGEMM_SMALL_KERNEL_B0_NN +#define GEMM_SMALL_KERNEL_B0_NT SGEMM_SMALL_KERNEL_B0_NT +#define GEMM_SMALL_KERNEL_B0_TN SGEMM_SMALL_KERNEL_B0_TN +#define GEMM_SMALL_KERNEL_B0_TT SGEMM_SMALL_KERNEL_B0_TT + #endif #else #ifdef XDOUBLE @@ -2063,6 +2099,48 @@ #define GEADD_K ZGEADD_K +#define GEMM_SMALL_MATRIX_PERMIT ZGEMM_SMALL_MATRIX_PERMIT + +#define GEMM_SMALL_KERNEL_NN ZGEMM_SMALL_KERNEL_NN +#define GEMM_SMALL_KERNEL_NT ZGEMM_SMALL_KERNEL_NT +#define GEMM_SMALL_KERNEL_NR ZGEMM_SMALL_KERNEL_NR +#define GEMM_SMALL_KERNEL_NC ZGEMM_SMALL_KERNEL_NC + +#define GEMM_SMALL_KERNEL_TN ZGEMM_SMALL_KERNEL_TN +#define GEMM_SMALL_KERNEL_TT ZGEMM_SMALL_KERNEL_TT +#define GEMM_SMALL_KERNEL_TR ZGEMM_SMALL_KERNEL_TR +#define GEMM_SMALL_KERNEL_TC ZGEMM_SMALL_KERNEL_TC + +#define GEMM_SMALL_KERNEL_RN ZGEMM_SMALL_KERNEL_RN +#define GEMM_SMALL_KERNEL_RT ZGEMM_SMALL_KERNEL_RT +#define GEMM_SMALL_KERNEL_RR ZGEMM_SMALL_KERNEL_RR +#define GEMM_SMALL_KERNEL_RC ZGEMM_SMALL_KERNEL_RC + +#define GEMM_SMALL_KERNEL_CN ZGEMM_SMALL_KERNEL_CN +#define GEMM_SMALL_KERNEL_CT ZGEMM_SMALL_KERNEL_CT +#define GEMM_SMALL_KERNEL_CR ZGEMM_SMALL_KERNEL_CR +#define GEMM_SMALL_KERNEL_CC ZGEMM_SMALL_KERNEL_CC + +#define GEMM_SMALL_KERNEL_B0_NN ZGEMM_SMALL_KERNEL_B0_NN +#define GEMM_SMALL_KERNEL_B0_NT ZGEMM_SMALL_KERNEL_B0_NT +#define GEMM_SMALL_KERNEL_B0_NR ZGEMM_SMALL_KERNEL_B0_NR +#define GEMM_SMALL_KERNEL_B0_NC ZGEMM_SMALL_KERNEL_B0_NC + +#define GEMM_SMALL_KERNEL_B0_TN ZGEMM_SMALL_KERNEL_B0_TN +#define GEMM_SMALL_KERNEL_B0_TT ZGEMM_SMALL_KERNEL_B0_TT +#define GEMM_SMALL_KERNEL_B0_TR ZGEMM_SMALL_KERNEL_B0_TR +#define GEMM_SMALL_KERNEL_B0_TC ZGEMM_SMALL_KERNEL_B0_TC + +#define GEMM_SMALL_KERNEL_B0_RN ZGEMM_SMALL_KERNEL_B0_RN +#define GEMM_SMALL_KERNEL_B0_RT ZGEMM_SMALL_KERNEL_B0_RT +#define GEMM_SMALL_KERNEL_B0_RR ZGEMM_SMALL_KERNEL_B0_RR +#define GEMM_SMALL_KERNEL_B0_RC ZGEMM_SMALL_KERNEL_B0_RC + +#define GEMM_SMALL_KERNEL_B0_CN ZGEMM_SMALL_KERNEL_B0_CN +#define GEMM_SMALL_KERNEL_B0_CT ZGEMM_SMALL_KERNEL_B0_CT +#define GEMM_SMALL_KERNEL_B0_CR ZGEMM_SMALL_KERNEL_B0_CR +#define GEMM_SMALL_KERNEL_B0_CC ZGEMM_SMALL_KERNEL_B0_CC + #else #define AMAX_K CAMAX_K @@ -2486,11 +2564,54 @@ #define GEADD_K CGEADD_K +#define GEMM_SMALL_MATRIX_PERMIT CGEMM_SMALL_MATRIX_PERMIT + +#define GEMM_SMALL_KERNEL_NN CGEMM_SMALL_KERNEL_NN +#define GEMM_SMALL_KERNEL_NT CGEMM_SMALL_KERNEL_NT +#define GEMM_SMALL_KERNEL_NR CGEMM_SMALL_KERNEL_NR +#define GEMM_SMALL_KERNEL_NC CGEMM_SMALL_KERNEL_NC + +#define GEMM_SMALL_KERNEL_TN CGEMM_SMALL_KERNEL_TN +#define GEMM_SMALL_KERNEL_TT CGEMM_SMALL_KERNEL_TT +#define GEMM_SMALL_KERNEL_TR CGEMM_SMALL_KERNEL_TR +#define GEMM_SMALL_KERNEL_TC CGEMM_SMALL_KERNEL_TC + +#define 
GEMM_SMALL_KERNEL_RN CGEMM_SMALL_KERNEL_RN +#define GEMM_SMALL_KERNEL_RT CGEMM_SMALL_KERNEL_RT +#define GEMM_SMALL_KERNEL_RR CGEMM_SMALL_KERNEL_RR +#define GEMM_SMALL_KERNEL_RC CGEMM_SMALL_KERNEL_RC + +#define GEMM_SMALL_KERNEL_CN CGEMM_SMALL_KERNEL_CN +#define GEMM_SMALL_KERNEL_CT CGEMM_SMALL_KERNEL_CT +#define GEMM_SMALL_KERNEL_CR CGEMM_SMALL_KERNEL_CR +#define GEMM_SMALL_KERNEL_CC CGEMM_SMALL_KERNEL_CC + +#define GEMM_SMALL_KERNEL_B0_NN CGEMM_SMALL_KERNEL_B0_NN +#define GEMM_SMALL_KERNEL_B0_NT CGEMM_SMALL_KERNEL_B0_NT +#define GEMM_SMALL_KERNEL_B0_NR CGEMM_SMALL_KERNEL_B0_NR +#define GEMM_SMALL_KERNEL_B0_NC CGEMM_SMALL_KERNEL_B0_NC + +#define GEMM_SMALL_KERNEL_B0_TN CGEMM_SMALL_KERNEL_B0_TN +#define GEMM_SMALL_KERNEL_B0_TT CGEMM_SMALL_KERNEL_B0_TT +#define GEMM_SMALL_KERNEL_B0_TR CGEMM_SMALL_KERNEL_B0_TR +#define GEMM_SMALL_KERNEL_B0_TC CGEMM_SMALL_KERNEL_B0_TC + +#define GEMM_SMALL_KERNEL_B0_RN CGEMM_SMALL_KERNEL_B0_RN +#define GEMM_SMALL_KERNEL_B0_RT CGEMM_SMALL_KERNEL_B0_RT +#define GEMM_SMALL_KERNEL_B0_RR CGEMM_SMALL_KERNEL_B0_RR +#define GEMM_SMALL_KERNEL_B0_RC CGEMM_SMALL_KERNEL_B0_RC + +#define GEMM_SMALL_KERNEL_B0_CN CGEMM_SMALL_KERNEL_B0_CN +#define GEMM_SMALL_KERNEL_B0_CT CGEMM_SMALL_KERNEL_B0_CT +#define GEMM_SMALL_KERNEL_B0_CR CGEMM_SMALL_KERNEL_B0_CR +#define GEMM_SMALL_KERNEL_B0_CC CGEMM_SMALL_KERNEL_B0_CC + #endif #endif #ifndef ASSEMBLER -#if defined(ARCH_X86) || defined(ARCH_X86_64) || defined(ARCH_IA64) || defined(ARCH_MIPS64) || defined(ARCH_ARM64) +#if defined(ARCH_X86) || defined(ARCH_X86_64) || defined(ARCH_IA64) || defined(ARCH_MIPS64) || defined(ARCH_ARM64)\ +|| defined(ARCH_LOONGARCH64) extern BLASLONG gemm_offset_a; extern BLASLONG gemm_offset_b; extern BLASLONG sbgemm_p; diff --git a/common_param.h b/common_param.h index 3e3ae06f8..31fba9059 100644 --- a/common_param.h +++ b/common_param.h @@ -145,6 +145,19 @@ BLASLONG (*isbmin_k) (BLASLONG, float *, BLASLONG); int (*sbneg_tcopy) (BLASLONG, BLASLONG, float *, BLASLONG, float *); int (*sblaswp_ncopy) (BLASLONG, BLASLONG, BLASLONG, float *, BLASLONG, blasint *, float *); +#ifdef SMALL_MATRIX_OPT + int (*sbgemm_small_matrix_permit)(int transa, int transb, BLASLONG m, BLASLONG n, BLASLONG k, float alpha, float beta); + + int (*sbgemm_small_kernel_nn )(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); + int (*sbgemm_small_kernel_nt )(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); + int (*sbgemm_small_kernel_tn )(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); + int (*sbgemm_small_kernel_tt )(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); + + int (*sbgemm_small_kernel_b0_nn )(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*sbgemm_small_kernel_b0_nt )(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*sbgemm_small_kernel_b0_tn )(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, bfloat16 * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*sbgemm_small_kernel_b0_tt )(BLASLONG m, BLASLONG n, BLASLONG k, bfloat16 * A, BLASLONG lda, float alpha, 
bfloat16 * B, BLASLONG ldb, float * C, BLASLONG ldc); +#endif #endif #if defined(BUILD_SINGLE) || defined(BUILD_COMPLEX) @@ -207,6 +220,20 @@ BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG); int (*sgemm_otcopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *); #endif #ifdef BUILD_SINGLE +#ifdef SMALL_MATRIX_OPT + int (*sgemm_small_matrix_permit)(int transa, int transb, BLASLONG m, BLASLONG n, BLASLONG k, float alpha, float beta); + + int (*sgemm_small_kernel_nn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); + int (*sgemm_small_kernel_nt )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); + int (*sgemm_small_kernel_tn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); + int (*sgemm_small_kernel_tt )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float beta, float * C, BLASLONG ldc); + + int (*sgemm_small_kernel_b0_nn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*sgemm_small_kernel_b0_nt )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*sgemm_small_kernel_b0_tn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*sgemm_small_kernel_b0_tt )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +#endif + int (*strsm_kernel_LN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); int (*strsm_kernel_LT)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); int (*strsm_kernel_RN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); @@ -314,6 +341,19 @@ BLASLONG (*idmin_k) (BLASLONG, double *, BLASLONG); int (*dgemm_otcopy )(BLASLONG, BLASLONG, double *, BLASLONG, double *); #endif #ifdef BUILD_DOUBLE +#ifdef SMALL_MATRIX_OPT + int (*dgemm_small_matrix_permit)(int transa, int transb, BLASLONG m, BLASLONG n, BLASLONG k, double alpha, double beta); + + int (*dgemm_small_kernel_nn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc); + int (*dgemm_small_kernel_nt )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc); + int (*dgemm_small_kernel_tn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc); + int (*dgemm_small_kernel_tt )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double beta, double * C, BLASLONG ldc); + + int (*dgemm_small_kernel_b0_nn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + int (*dgemm_small_kernel_b0_nt )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + int (*dgemm_small_kernel_b0_tn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, 
double * B, BLASLONG ldb, double * C, BLASLONG ldc); + int (*dgemm_small_kernel_b0_tt )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +#endif int (*dtrsm_kernel_LN)(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG, BLASLONG); int (*dtrsm_kernel_LT)(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG, BLASLONG); int (*dtrsm_kernel_RN)(BLASLONG, BLASLONG, BLASLONG, double, double *, double *, double *, BLASLONG, BLASLONG); @@ -513,6 +553,50 @@ BLASLONG (*icamin_k)(BLASLONG, float *, BLASLONG); int (*cgemm_oncopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *); int (*cgemm_otcopy )(BLASLONG, BLASLONG, float *, BLASLONG, float *); +#ifdef SMALL_MATRIX_OPT + int (*cgemm_small_matrix_permit)(int transa, int transb, BLASLONG m, BLASLONG n, BLASLONG k, float alpha0, float alpha1, float beta0, float beta1); + + int (*cgemm_small_kernel_nn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_nt )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_nr )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_nc )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); + + int (*cgemm_small_kernel_tn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_tt )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_tr )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_tc )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); + + int (*cgemm_small_kernel_rn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_rt )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_rr )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_rc )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); + + int (*cgemm_small_kernel_cn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); + int 
(*cgemm_small_kernel_ct )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_cr )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_cc )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float beta0, float beta1, float * C, BLASLONG ldc); + + int (*cgemm_small_kernel_b0_nn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_b0_nt )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_b0_nr )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_b0_nc )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + + int (*cgemm_small_kernel_b0_tn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_b0_tt )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_b0_tr )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_b0_tc )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + + int (*cgemm_small_kernel_b0_rn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_b0_rt )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_b0_rr )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_b0_rc )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + + int (*cgemm_small_kernel_b0_cn )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_b0_ct )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_b0_cr )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); + int (*cgemm_small_kernel_b0_cc )(BLASLONG m, BLASLONG n, BLASLONG k, float * A, BLASLONG lda, float alpha0, float alpha1, float * B, BLASLONG ldb, float * C, BLASLONG ldc); +#endif + int (*ctrsm_kernel_LN)(BLASLONG, BLASLONG, BLASLONG, float, float, float *, float *, float *, BLASLONG, BLASLONG); int (*ctrsm_kernel_LT)(BLASLONG, 
BLASLONG, BLASLONG, float, float, float *, float *, float *, BLASLONG, BLASLONG); int (*ctrsm_kernel_LR)(BLASLONG, BLASLONG, BLASLONG, float, float, float *, float *, float *, BLASLONG, BLASLONG); @@ -679,6 +763,50 @@ BLASLONG (*izamin_k)(BLASLONG, double *, BLASLONG); int (*zgemm_oncopy )(BLASLONG, BLASLONG, double *, BLASLONG, double *); int (*zgemm_otcopy )(BLASLONG, BLASLONG, double *, BLASLONG, double *); +#ifdef SMALL_MATRIX_OPT + int (*zgemm_small_matrix_permit)(int transa, int transb, BLASLONG m, BLASLONG n, BLASLONG k, double alpha0, double alpha1, double beta0, double beta1); + + int (*zgemm_small_kernel_nn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_nt )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_nr )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_nc )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); + + int (*zgemm_small_kernel_tn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_tt )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_tr )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_tc )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); + + int (*zgemm_small_kernel_rn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_rt )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_rr )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_rc )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); + + int (*zgemm_small_kernel_cn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_ct )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_cr )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, 
BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_cc )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double beta0, double beta1, double * C, BLASLONG ldc); + + int (*zgemm_small_kernel_b0_nn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_b0_nt )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_b0_nr )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_b0_nc )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + + int (*zgemm_small_kernel_b0_tn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_b0_tt )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_b0_tr )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_b0_tc )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + + int (*zgemm_small_kernel_b0_rn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_b0_rt )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_b0_rr )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_b0_rc )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + + int (*zgemm_small_kernel_b0_cn )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_b0_ct )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_b0_cr )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); + int (*zgemm_small_kernel_b0_cc )(BLASLONG m, BLASLONG n, BLASLONG k, double * A, BLASLONG lda, double alpha0, double alpha1, double * B, BLASLONG ldb, double * C, BLASLONG ldc); +#endif + int (*ztrsm_kernel_LN)(BLASLONG, BLASLONG, BLASLONG, double, double, double *, double *, double *, BLASLONG, BLASLONG); int (*ztrsm_kernel_LT)(BLASLONG, BLASLONG, BLASLONG, double, double, double *, double *, double *, BLASLONG, BLASLONG); int (*ztrsm_kernel_LR)(BLASLONG, BLASLONG, BLASLONG, double, double, double *, double 
*, double *, BLASLONG, BLASLONG); @@ -1069,6 +1197,8 @@ BLASLONG (*ixamin_k)(BLASLONG, xdouble *, BLASLONG); extern gotoblas_t *gotoblas; +#define FUNC_OFFSET(func) (size_t)(&((gotoblas_t *)NULL)->func) + #define DTB_ENTRIES gotoblas -> dtb_entries #define GEMM_OFFSET_A gotoblas -> offsetA #define GEMM_OFFSET_B gotoblas -> offsetB @@ -1174,6 +1304,8 @@ extern gotoblas_t *gotoblas; #else +#define FUNC_OFFSET(func) (size_t)(func) + #define DTB_ENTRIES DTB_DEFAULT_ENTRIES #define GEMM_OFFSET_A GEMM_DEFAULT_OFFSET_A diff --git a/common_s.h b/common_s.h index 34903ec49..fdd80b62f 100644 --- a/common_s.h +++ b/common_s.h @@ -164,6 +164,8 @@ #define SGEADD_K sgeadd_k +#define SGEMM_SMALL_MATRIX_PERMIT sgemm_small_matrix_permit + #else #define SAMAX_K gotoblas -> samax_k @@ -299,8 +301,21 @@ #define SGEADD_K gotoblas -> sgeadd_k +#define SGEMM_SMALL_MATRIX_PERMIT gotoblas -> sgemm_small_matrix_permit + #endif +#define SGEMM_SMALL_KERNEL_NN FUNC_OFFSET(sgemm_small_kernel_nn) +#define SGEMM_SMALL_KERNEL_NT FUNC_OFFSET(sgemm_small_kernel_nt) +#define SGEMM_SMALL_KERNEL_TN FUNC_OFFSET(sgemm_small_kernel_tn) +#define SGEMM_SMALL_KERNEL_TT FUNC_OFFSET(sgemm_small_kernel_tt) + +#define SGEMM_SMALL_KERNEL_B0_NN FUNC_OFFSET(sgemm_small_kernel_b0_nn) +#define SGEMM_SMALL_KERNEL_B0_NT FUNC_OFFSET(sgemm_small_kernel_b0_nt) +#define SGEMM_SMALL_KERNEL_B0_TN FUNC_OFFSET(sgemm_small_kernel_b0_tn) +#define SGEMM_SMALL_KERNEL_B0_TT FUNC_OFFSET(sgemm_small_kernel_b0_tt) + + #define SGEMM_NN sgemm_nn #define SGEMM_CN sgemm_tn #define SGEMM_TN sgemm_tn diff --git a/common_sb.h b/common_sb.h index 9976e812e..d21e7a563 100644 --- a/common_sb.h +++ b/common_sb.h @@ -24,6 +24,7 @@ #define SBGEMM_BETA sbgemm_beta #define SBGEMM_KERNEL sbgemm_kernel +#define SBGEMM_SMALL_MATRIX_PERMIT sbgemm_small_matrix_permit #else #define SBDOT_K gotoblas -> sbdot_k @@ -41,8 +42,19 @@ #define SBGEMM_BETA gotoblas -> sbgemm_beta #define SBGEMM_KERNEL gotoblas -> sbgemm_kernel +#define SBGEMM_SMALL_MATRIX_PERMIT gotoblas -> sbgemm_small_matrix_permit #endif +#define SBGEMM_SMALL_KERNEL_NN FUNC_OFFSET(sbgemm_small_kernel_nn) +#define SBGEMM_SMALL_KERNEL_NT FUNC_OFFSET(sbgemm_small_kernel_nt) +#define SBGEMM_SMALL_KERNEL_TN FUNC_OFFSET(sbgemm_small_kernel_tn) +#define SBGEMM_SMALL_KERNEL_TT FUNC_OFFSET(sbgemm_small_kernel_tt) + +#define SBGEMM_SMALL_KERNEL_B0_NN FUNC_OFFSET(sbgemm_small_kernel_b0_nn) +#define SBGEMM_SMALL_KERNEL_B0_NT FUNC_OFFSET(sbgemm_small_kernel_b0_nt) +#define SBGEMM_SMALL_KERNEL_B0_TN FUNC_OFFSET(sbgemm_small_kernel_b0_tn) +#define SBGEMM_SMALL_KERNEL_B0_TT FUNC_OFFSET(sbgemm_small_kernel_b0_tt) + #define SBGEMM_NN sbgemm_nn #define SBGEMM_CN sbgemm_tn #define SBGEMM_TN sbgemm_tn diff --git a/common_z.h b/common_z.h index f1e78dd08..c12d71b39 100644 --- a/common_z.h +++ b/common_z.h @@ -232,6 +232,8 @@ #define ZGEADD_K zgeadd_k +#define ZGEMM_SMALL_MATRIX_PERMIT zgemm_small_matrix_permit + #else #define ZAMAX_K gotoblas -> zamax_k @@ -426,8 +428,51 @@ #define ZGEADD_K gotoblas -> zgeadd_k +#define ZGEMM_SMALL_MATRIX_PERMIT gotoblas -> zgemm_small_matrix_permit + #endif +#define ZGEMM_SMALL_KERNEL_NN FUNC_OFFSET(zgemm_small_kernel_nn) +#define ZGEMM_SMALL_KERNEL_NT FUNC_OFFSET(zgemm_small_kernel_nt) +#define ZGEMM_SMALL_KERNEL_NR FUNC_OFFSET(zgemm_small_kernel_nr) +#define ZGEMM_SMALL_KERNEL_NC FUNC_OFFSET(zgemm_small_kernel_nc) + +#define ZGEMM_SMALL_KERNEL_TN FUNC_OFFSET(zgemm_small_kernel_tn) +#define ZGEMM_SMALL_KERNEL_TT FUNC_OFFSET(zgemm_small_kernel_tt) +#define ZGEMM_SMALL_KERNEL_TR 
FUNC_OFFSET(zgemm_small_kernel_tr) +#define ZGEMM_SMALL_KERNEL_TC FUNC_OFFSET(zgemm_small_kernel_tc) + +#define ZGEMM_SMALL_KERNEL_RN FUNC_OFFSET(zgemm_small_kernel_rn) +#define ZGEMM_SMALL_KERNEL_RT FUNC_OFFSET(zgemm_small_kernel_rt) +#define ZGEMM_SMALL_KERNEL_RR FUNC_OFFSET(zgemm_small_kernel_rr) +#define ZGEMM_SMALL_KERNEL_RC FUNC_OFFSET(zgemm_small_kernel_rc) + +#define ZGEMM_SMALL_KERNEL_CN FUNC_OFFSET(zgemm_small_kernel_cn) +#define ZGEMM_SMALL_KERNEL_CT FUNC_OFFSET(zgemm_small_kernel_ct) +#define ZGEMM_SMALL_KERNEL_CR FUNC_OFFSET(zgemm_small_kernel_cr) +#define ZGEMM_SMALL_KERNEL_CC FUNC_OFFSET(zgemm_small_kernel_cc) + +#define ZGEMM_SMALL_KERNEL_B0_NN FUNC_OFFSET(zgemm_small_kernel_b0_nn) +#define ZGEMM_SMALL_KERNEL_B0_NT FUNC_OFFSET(zgemm_small_kernel_b0_nt) +#define ZGEMM_SMALL_KERNEL_B0_NR FUNC_OFFSET(zgemm_small_kernel_b0_nr) +#define ZGEMM_SMALL_KERNEL_B0_NC FUNC_OFFSET(zgemm_small_kernel_b0_nc) + +#define ZGEMM_SMALL_KERNEL_B0_TN FUNC_OFFSET(zgemm_small_kernel_b0_tn) +#define ZGEMM_SMALL_KERNEL_B0_TT FUNC_OFFSET(zgemm_small_kernel_b0_tt) +#define ZGEMM_SMALL_KERNEL_B0_TR FUNC_OFFSET(zgemm_small_kernel_b0_tr) +#define ZGEMM_SMALL_KERNEL_B0_TC FUNC_OFFSET(zgemm_small_kernel_b0_tc) + +#define ZGEMM_SMALL_KERNEL_B0_RN FUNC_OFFSET(zgemm_small_kernel_b0_rn) +#define ZGEMM_SMALL_KERNEL_B0_RT FUNC_OFFSET(zgemm_small_kernel_b0_rt) +#define ZGEMM_SMALL_KERNEL_B0_RR FUNC_OFFSET(zgemm_small_kernel_b0_rr) +#define ZGEMM_SMALL_KERNEL_B0_RC FUNC_OFFSET(zgemm_small_kernel_b0_rc) + +#define ZGEMM_SMALL_KERNEL_B0_CN FUNC_OFFSET(zgemm_small_kernel_b0_cn) +#define ZGEMM_SMALL_KERNEL_B0_CT FUNC_OFFSET(zgemm_small_kernel_b0_ct) +#define ZGEMM_SMALL_KERNEL_B0_CR FUNC_OFFSET(zgemm_small_kernel_b0_cr) +#define ZGEMM_SMALL_KERNEL_B0_CC FUNC_OFFSET(zgemm_small_kernel_b0_cc) + + #define ZGEMM_NN zgemm_nn #define ZGEMM_CN zgemm_cn #define ZGEMM_TN zgemm_tn diff --git a/cpuid_loongarch64.c b/cpuid_loongarch64.c new file mode 100644 index 000000000..79b186bf1 --- /dev/null +++ b/cpuid_loongarch64.c @@ -0,0 +1,110 @@ +/***************************************************************************** +Copyright (c) 2011-2020, The OpenBLAS Project +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in + the documentation and/or other materials provided with the + distribution. + 3. Neither the name of the OpenBLAS project nor the names of + its contributors may be used to endorse or promote products + derived from this software without specific prior written + permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +**********************************************************************************/ + +#include <stdint.h> +#include <stdio.h> + +#define CPU_UNKNOWN 0 +#define CPU_LOONGSON3R5 1 + +#define LOONGARCH_CFG2 0x02 +#define LOONGARCH_LASX 1<<7 + +static char *cpuname[] = { + "UNKNOWN", + "LOONGSON3R5" +}; + +int detect(void) { + uint32_t reg = 0; + + __asm__ volatile ( + "cpucfg %0, %1 \n\t" + : "+&r"(reg) + : "r"(LOONGARCH_CFG2) + ); + + if (reg & LOONGARCH_LASX) + return CPU_LOONGSON3R5; + else + return CPU_UNKNOWN; +} + +char *get_corename(void) { + return cpuname[detect()]; +} + +void get_architecture(void) { + printf("LOONGARCH64"); +} + +void get_subarchitecture(void) { + if (detect() == CPU_LOONGSON3R5) { + printf("LOONGSON3R5"); + } else { + printf("UNKNOWN"); + } +} + +void get_subdirname(void) { + printf("loongarch64"); +} + +void get_cpuconfig(void) { + if (detect() == CPU_LOONGSON3R5) { + printf("#define LOONGSON3R5\n"); + printf("#define L1_DATA_SIZE 65536\n"); + printf("#define L1_DATA_LINESIZE 64\n"); + printf("#define L2_SIZE 1048576\n"); + printf("#define L2_LINESIZE 64\n"); + printf("#define DTB_DEFAULT_ENTRIES 64\n"); + printf("#define DTB_SIZE 4096\n"); + printf("#define L2_ASSOCIATIVE 16\n"); + } else { + printf("#define LOONGSON3R5\n"); + printf("#define L1_DATA_SIZE 65536\n"); + printf("#define L1_DATA_LINESIZE 64\n"); + printf("#define L2_SIZE 1048576\n"); + printf("#define L2_LINESIZE 64\n"); + printf("#define DTB_DEFAULT_ENTRIES 64\n"); + printf("#define DTB_SIZE 4096\n"); + printf("#define L2_ASSOCIATIVE 16\n"); + } +} + +void get_libname(void){ + if (detect() == CPU_LOONGSON3R5) { + printf("loongson3r5\n"); + } else { + printf("loongarch64\n"); + } +} diff --git a/ctest.c b/ctest.c index d674a8cbd..2afd93f68 100644 --- a/ctest.c +++ b/ctest.c @@ -84,7 +84,7 @@ OS_AIX OS_OSF #endif -#if defined(__WIN32) || defined(__WIN64) || defined(__WINNT) +#if defined(__WIN32) || defined(__WIN64) || defined(_WIN32) || defined(_WIN64) || defined(__WINNT) OS_WINNT #endif @@ -141,7 +141,7 @@ ARCH_SPARC ARCH_IA64 #endif -#if defined(__LP64) || defined(__LP64__) || defined(__ptr64) || defined(__x86_64__) || defined(__amd64__) || defined(__64BIT__) +#if defined(__LP64) || defined(__LP64__) || defined(__ptr64) || defined(__x86_64__) || defined(__amd64__) || defined(__64BIT__) || defined(__aarch64__) BINARY_64 #endif @@ -157,6 +157,10 @@ ARCH_ARM64 ARCH_RISCV64 #endif +#ifdef __loongarch64 +ARCH_LOONGARCH64 +#endif + #if (defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L) HAVE_C11 #endif diff --git a/ctest/CMakeLists.txt b/ctest/CMakeLists.txt index 17f29fe69..f785d3f90 100644 --- a/ctest/CMakeLists.txt +++ b/ctest/CMakeLists.txt @@ -4,6 +4,9 @@ include_directories(${PROJECT_BINARY_DIR}) enable_language(Fortran) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DADD${BU} -DCBLAS") +if (CMAKE_Fortran_COMPILER_ID STREQUAL GNU) + set(CMAKE_Fortran_FLAGS "${CMAKE_Fortran_FLAGS} -fno-tree-vectorize") +endif() if(WIN32) FILE(WRITE ${CMAKE_CURRENT_BINARY_DIR}/test_cblas_helper.ps1 diff --git a/ctest/Makefile
b/ctest/Makefile index 15c83a907..c5e1094da 100644 --- a/ctest/Makefile +++ b/ctest/Makefile @@ -6,6 +6,9 @@ TOPDIR = .. include $(TOPDIR)/Makefile.system override CFLAGS += -DADD$(BU) -DCBLAS +ifeq ($(F_COMPILER),GFORTRAN) + override FFLAGS += -fno-tree-vectorize +endif override TARGET_ARCH= override TARGET_MACH= diff --git a/driver/level2/CMakeLists.txt b/driver/level2/CMakeLists.txt index 61367e596..3e9964ab1 100644 --- a/driver/level2/CMakeLists.txt +++ b/driver/level2/CMakeLists.txt @@ -81,6 +81,7 @@ foreach (float_type ${FLOAT_TYPES}) GenerateNamedObjects("gbmv_thread.c" "TRANSA" "gbmv_thread_t" false "" "" false ${float_type}) endif () +# special defines for complex if (${float_type} STREQUAL "COMPLEX" OR ${float_type} STREQUAL "ZCOMPLEX") foreach (u_source ${U_SOURCES}) @@ -197,6 +198,13 @@ foreach (float_type ${FLOAT_TYPES}) endif () endforeach () +if (BUILD_BFLOAT16) + if (USE_THREAD) + GenerateNamedObjects("sbgemv_thread.c" "" "gemv_thread_n" false "" "" false "BFLOAT16") + GenerateNamedObjects("sbgemv_thread.c" "TRANSA" "gemv_thread_t" false "" "" false "BFLOAT16") + endif () +endif () + if ( BUILD_COMPLEX AND NOT BUILD_SINGLE) if (USE_THREAD) GenerateNamedObjects("gemv_thread.c" "" "gemv_thread_n" false "" "" false "SINGLE") diff --git a/driver/level3/CMakeLists.txt b/driver/level3/CMakeLists.txt index 077862abc..75b25d039 100644 --- a/driver/level3/CMakeLists.txt +++ b/driver/level3/CMakeLists.txt @@ -12,6 +12,12 @@ foreach (GEMM_DEFINE ${GEMM_DEFINES}) if (USE_THREAD AND NOT USE_SIMPLE_THREADED_LEVEL3) GenerateNamedObjects("gemm.c" "${GEMM_DEFINE};THREADED_LEVEL3" "gemm_thread_${GEMM_DEFINE_LC}" 0) endif () + if (BUILD_BFLOAT16) + GenerateNamedObjects("gemm.c" "${GEMM_DEFINE}" "gemm_${GEMM_DEFINE_LC}" 0 "" "" false "BFLOAT16") + if (USE_THREAD AND NOT USE_SIMPLE_THREADED_LEVEL3) + GenerateNamedObjects("gemm.c" "${GEMM_DEFINE};THREADED_LEVEL3" "gemm_thread_${GEMM_DEFINE_LC}" 0 "" "" false "BFLOAT16") + endif () + endif () endforeach () if ( BUILD_COMPLEX16 AND NOT BUILD_DOUBLE) diff --git a/driver/others/dynamic_power.c b/driver/others/dynamic_power.c index d9c15b312..2847ea9ae 100644 --- a/driver/others/dynamic_power.c +++ b/driver/others/dynamic_power.c @@ -6,10 +6,6 @@ extern gotoblas_t gotoblas_POWER8; #if (!defined __GNUC__) || ( __GNUC__ >= 6) extern gotoblas_t gotoblas_POWER9; #endif -//#if (!defined __GNUC__) || ( __GNUC__ >= 11) \ -// || (__GNUC__ == 10 && __GNUC_MINOR__ >= 2) -//#define HAVE_P10_SUPPORT 1 -//#endif #ifdef HAVE_P10_SUPPORT extern gotoblas_t gotoblas_POWER10; #endif diff --git a/driver/others/memory.c b/driver/others/memory.c index 6e654ccf2..0185fa683 100644 --- a/driver/others/memory.c +++ b/driver/others/memory.c @@ -73,6 +73,16 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
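Worth calling out from the common_param.h and common_*.h hunks above: the FUNC_OFFSET macro is what lets the small-matrix kernels participate in dynamic dispatch. With DYNAMIC_ARCH it evaluates to the byte offset of a kernel's slot inside the gotoblas_t table (an offsetof-style computation through a null pointer); in static builds it degenerates to the function address itself. A hedged sketch of how such an offset can be stored and later resolved, using an illustrative stand-in struct rather than the real gotoblas_t:

#include <stddef.h>

/* Illustrative stand-in for gotoblas_t; the real table has many members. */
typedef struct {
    int (*sgemm_small_kernel_nn)(void);
} dispatch_t;

/* Same trick as the patch: member offset computed through a null pointer. */
#define FUNC_OFFSET(func) ((size_t)&(((dispatch_t *)NULL)->func))

/* Recover the function pointer from a live table plus the stored offset. */
static int (*resolve_nn(dispatch_t *tbl))(void) {
    return *(int (**)(void))((char *)tbl + FUNC_OFFSET(sgemm_small_kernel_nn));
}

Storing offsets rather than raw pointers means the same compile-time constant works no matter which CPU-specific dispatch table is selected at runtime.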
#include "common.h" +#ifndef likely +#ifdef __GNUC__ +#define likely(x) __builtin_expect(!!(x), 1) +#define unlikely(x) __builtin_expect(!!(x), 0) +#else +#define likely(x) (x) +#define unlikely(x) (x) +#endif +#endif + #if defined(USE_TLS) && defined(SMP) #define COMPILE_TLS @@ -428,7 +438,7 @@ extern int openblas_goto_num_threads_env(); extern int openblas_omp_num_threads_env(); int blas_get_cpu_number(void){ -#if defined(OS_LINUX) || defined(OS_WINDOWS) || defined(OS_FREEBSD) || defined(OS_OPENBSD) || defined(OS_NETBSD) || defined(OS_DRAGONFLY) || defined(OS_DARWIN) || defined(OS_ANDROID) +#if defined(OS_LINUX) || defined(OS_WINDOWS) || defined(OS_FREEBSD) || defined(OS_OPENBSD) || defined(OS_NETBSD) || defined(OS_DRAGONFLY) || defined(OS_DARWIN) || defined(OS_ANDROID) || defined(OS_HAIKU) int max_num; #endif int blas_goto_num = 0; @@ -436,7 +446,7 @@ int blas_get_cpu_number(void){ if (blas_num_threads) return blas_num_threads; -#if defined(OS_LINUX) || defined(OS_WINDOWS) || defined(OS_FREEBSD) || defined(OS_OPENBSD) || defined(OS_NETBSD) || defined(OS_DRAGONFLY) || defined(OS_DARWIN) || defined(OS_ANDROID) +#if defined(OS_LINUX) || defined(OS_WINDOWS) || defined(OS_FREEBSD) || defined(OS_OPENBSD) || defined(OS_NETBSD) || defined(OS_DRAGONFLY) || defined(OS_DARWIN) || defined(OS_ANDROID) || defined(OS_HAIKU) max_num = get_num_procs(); #endif @@ -460,7 +470,7 @@ int blas_get_cpu_number(void){ else if (blas_omp_num > 0) blas_num_threads = blas_omp_num; else blas_num_threads = MAX_CPU_NUMBER; -#if defined(OS_LINUX) || defined(OS_WINDOWS) || defined(OS_FREEBSD) || defined(OS_OPENBSD) || defined(OS_NETBSD) || defined(OS_DRAGONFLY) || defined(OS_DARWIN) || defined(OS_ANDROID) +#if defined(OS_LINUX) || defined(OS_WINDOWS) || defined(OS_FREEBSD) || defined(OS_OPENBSD) || defined(OS_NETBSD) || defined(OS_DRAGONFLY) || defined(OS_DARWIN) || defined(OS_ANDROID) || defined(OS_HAIKU) if (blas_num_threads > max_num) blas_num_threads = max_num; #endif @@ -1291,7 +1301,12 @@ UNLOCK_COMMAND(&alloc_lock); return (void *)(((char *)alloc_info) + sizeof(struct alloc_t)); error: - printf("OpenBLAS : Program will terminate because you tried to allocate too many memory regions.\n"); + printf("OpenBLAS : Program will terminate because you tried to allocate too many TLS memory regions.\n"); + printf("This library was built to support a maximum of %d threads - either rebuild OpenBLAS\n", NUM_BUFFERS); + printf("with a larger NUM_THREADS value or set the environment variable OPENBLAS_NUM_THREADS to\n"); + printf("a sufficiently small number. 
This error typically occurs when the software that relies on\n"); + printf("OpenBLAS calls BLAS functions from many threads in parallel, or when your computer has more\n"); + printf("cpu cores than what OpenBLAS was configured to handle.\n"); return NULL; } @@ -1979,7 +1994,7 @@ extern int openblas_goto_num_threads_env(); extern int openblas_omp_num_threads_env(); int blas_get_cpu_number(void){ -#if defined(OS_LINUX) || defined(OS_WINDOWS) || defined(OS_FREEBSD) || defined(OS_OPENBSD) || defined(OS_NETBSD) || defined(OS_DRAGONFLY) || defined(OS_DARWIN) || defined(OS_ANDROID) +#if defined(OS_LINUX) || defined(OS_WINDOWS) || defined(OS_FREEBSD) || defined(OS_OPENBSD) || defined(OS_NETBSD) || defined(OS_DRAGONFLY) || defined(OS_DARWIN) || defined(OS_ANDROID) || defined(OS_HAIKU) int max_num; #endif int blas_goto_num = 0; @@ -1987,7 +2002,7 @@ int blas_get_cpu_number(void){ if (blas_num_threads) return blas_num_threads; -#if defined(OS_LINUX) || defined(OS_WINDOWS) || defined(OS_FREEBSD) || defined(OS_OPENBSD) || defined(OS_NETBSD) || defined(OS_DRAGONFLY) || defined(OS_DARWIN) || defined(OS_ANDROID) +#if defined(OS_LINUX) || defined(OS_WINDOWS) || defined(OS_FREEBSD) || defined(OS_OPENBSD) || defined(OS_NETBSD) || defined(OS_DRAGONFLY) || defined(OS_DARWIN) || defined(OS_ANDROID) || defined(OS_HAIKU) max_num = get_num_procs(); #endif @@ -2011,7 +2026,7 @@ int blas_get_cpu_number(void){ else if (blas_omp_num > 0) blas_num_threads = blas_omp_num; else blas_num_threads = MAX_CPU_NUMBER; -#if defined(OS_LINUX) || defined(OS_WINDOWS) || defined(OS_FREEBSD) || defined(OS_OPENBSD) || defined(OS_NETBSD) || defined(OS_DRAGONFLY) || defined(OS_DARWIN) || defined(OS_ANDROID) +#if defined(OS_LINUX) || defined(OS_WINDOWS) || defined(OS_FREEBSD) || defined(OS_OPENBSD) || defined(OS_NETBSD) || defined(OS_DRAGONFLY) || defined(OS_DARWIN) || defined(OS_ANDROID) || defined(OS_HAIKU) if (blas_num_threads > max_num) blas_num_threads = max_num; #endif @@ -2055,6 +2070,7 @@ struct release_t { int hugetlb_allocated = 0; static struct release_t release_info[NUM_BUFFERS]; +static struct release_t *new_release_info; static int release_pos = 0; #if defined(OS_LINUX) && !defined(NO_WARMUP) @@ -2105,8 +2121,13 @@ static void *alloc_mmap(void *address){ #if (defined(SMP) || defined(USE_LOCKING)) && !defined(USE_OPENMP) LOCK_COMMAND(&alloc_lock); #endif + if (likely(release_pos < NUM_BUFFERS)) { release_info[release_pos].address = map_address; release_info[release_pos].func = alloc_mmap_free; + } else { + new_release_info[release_pos-NUM_BUFFERS].address = map_address; + new_release_info[release_pos-NUM_BUFFERS].func = alloc_mmap_free; + } release_pos ++; #if (defined(SMP) || defined(USE_LOCKING)) && !defined(USE_OPENMP) UNLOCK_COMMAND(&alloc_lock); @@ -2269,8 +2290,13 @@ static void *alloc_mmap(void *address){ #if (defined(SMP) || defined(USE_LOCKING)) && !defined(USE_OPENMP) LOCK_COMMAND(&alloc_lock); #endif + if (likely(release_pos < NUM_BUFFERS)) { release_info[release_pos].address = map_address; release_info[release_pos].func = alloc_mmap_free; + } else { + new_release_info[release_pos-NUM_BUFFERS].address = map_address; + new_release_info[release_pos-NUM_BUFFERS].func = alloc_mmap_free; + } release_pos ++; #if (defined(SMP) || defined(USE_LOCKING)) && !defined(USE_OPENMP) UNLOCK_COMMAND(&alloc_lock); @@ -2302,8 +2328,13 @@ static void *alloc_malloc(void *address){ if (map_address == (void *)NULL) map_address = (void *)-1; if (map_address != (void *)-1) { + if (likely(release_pos < NUM_BUFFERS)) { 
release_info[release_pos].address = map_address; release_info[release_pos].func = alloc_malloc_free; + } else { + new_release_info[release_pos-NUM_BUFFERS].address = map_address; + new_release_info[release_pos-NUM_BUFFERS].func = alloc_malloc_free; + } release_pos ++; } @@ -2336,8 +2367,13 @@ static void *alloc_qalloc(void *address){ if (map_address == (void *)NULL) map_address = (void *)-1; if (map_address != (void *)-1) { + if (likely(release_pos < NUM_BUFFERS)) { release_info[release_pos].address = map_address; release_info[release_pos].func = alloc_qalloc_free; + } else { + new_release_info[release_pos-NUM_BUFFERS].address = map_address; + new_release_info[release_pos-NUM_BUFFERS].func = alloc_qalloc_free; + } release_pos ++; } @@ -2365,8 +2401,13 @@ static void *alloc_windows(void *address){ if (map_address == (void *)NULL) map_address = (void *)-1; if (map_address != (void *)-1) { + if (likely(release_pos < NUM_BUFFERS)) { release_info[release_pos].address = map_address; release_info[release_pos].func = alloc_windows_free; + } else { + new_release_info[release_pos-NUM_BUFFERS].address = map_address; + new_release_info[release_pos-NUM_BUFFERS].func = alloc_windows_free; + } release_pos ++; } @@ -2409,9 +2450,15 @@ static void *alloc_devicedirver(void *address){ fd, 0); if (map_address != (void *)-1) { + if (likely(release_pos < NUM_BUFFERS)) { release_info[release_pos].address = map_address; release_info[release_pos].attr = fd; release_info[release_pos].func = alloc_devicedirver_free; + } else { + new_release_info[release_pos-NUM_BUFFERS].address = map_address; + new_release_info[release_pos-NUM_BUFFERS].attr = fd; + new_release_info[release_pos-NUM_BUFFERS].func = alloc_devicedirver_free; + } release_pos ++; } @@ -2445,9 +2492,15 @@ static void *alloc_shm(void *address){ shmctl(shmid, IPC_RMID, 0); + if (likely(release_pos < NUM_BUFFERS)) { release_info[release_pos].address = map_address; release_info[release_pos].attr = shmid; release_info[release_pos].func = alloc_shm_free; + } else { + new_release_info[release_pos-NUM_BUFFERS].address = map_address; + new_release_info[release_pos-NUM_BUFFERS].attr = shmid; + new_release_info[release_pos-NUM_BUFFERS].func = alloc_shm_free; + } release_pos ++; } @@ -2551,8 +2604,13 @@ static void *alloc_hugetlb(void *address){ #endif if (map_address != (void *)-1){ + if (likely(release_pos < NUM_BUFFERS)) { release_info[release_pos].address = map_address; release_info[release_pos].func = alloc_hugetlb_free; + } else { + new_release_info[release_pos-NUM_BUFFERS].address = map_address; + new_release_info[release_pos-NUM_BUFFERS].func = alloc_hugetlb_free; + } release_pos ++; } @@ -2599,9 +2657,15 @@ static void *alloc_hugetlbfile(void *address){ fd, 0); if (map_address != (void *)-1) { + if (likely(release_pos < NUM_BUFFERS)) { release_info[release_pos].address = map_address; release_info[release_pos].attr = fd; release_info[release_pos].func = alloc_hugetlbfile_free; + } else { + new_release_info[release_pos-NUM_BUFFERS].address = map_address; + new_release_info[release_pos-NUM_BUFFERS].attr = fd; + new_release_info[release_pos-NUM_BUFFERS].func = alloc_hugetlbfile_free; + } release_pos ++; } @@ -2631,8 +2695,25 @@ static volatile struct { } memory[NUM_BUFFERS]; -static int memory_initialized = 0; +struct newmemstruct +{ + BLASULONG lock; + void *addr; +#if defined(WHEREAMI) && !defined(USE_OPENMP) + int pos; +#endif + int used; +#ifndef __64BIT__ + char dummy[48]; +#else + char dummy[40]; +#endif +}; +static volatile struct newmemstruct *newmemory; 
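The newmemstruct table and new_release_info array introduced in this memory.c hunk mirror the static memory[] and release_info[] tables: once more buffers are requested than the compile-time NUM_BUFFERS, the allocator spills into lazily malloc'd auxiliary arrays of 512 entries, addressed with the same position - NUM_BUFFERS shift used throughout the patch (the dummy padding appears intended to keep entries from sharing cache lines). A minimal sketch of that two-tier bookkeeping, with simplified types and an assumed NUM_BUFFERS value:

#include <stdlib.h>

#define NUM_BUFFERS    8    /* stands in for the build-time value */
#define OVERFLOW_SLOTS 512  /* matches the auxiliary array size in the patch */

static void  *fixed_table[NUM_BUFFERS];
static void **spill_table;  /* allocated lazily on first overflow */

static void record_mapping(int pos, void *addr) {
    if (pos < NUM_BUFFERS) {
        fixed_table[pos] = addr;              /* fast path: static table */
    } else {
        if (spill_table == NULL)              /* first overflow: grow lazily */
            spill_table = calloc(OVERFLOW_SLOTS, sizeof *spill_table);
        spill_table[pos - NUM_BUFFERS] = addr; /* same shifted indexing */
    }
}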
+ +static int memory_initialized = 0; +static int memory_overflowed = 0; /* Memory allocation routine */ /* procpos ... indicates where it comes from */ /* 0 : Level 3 functions */ @@ -2641,6 +2722,8 @@ static int memory_initialized = 0; void *blas_memory_alloc(int procpos){ + int i; + int position; #if defined(WHEREAMI) && !defined(USE_OPENMP) int mypos = 0; @@ -2774,6 +2857,29 @@ void *blas_memory_alloc(int procpos){ #if (defined(SMP) || defined(USE_LOCKING)) && !defined(USE_OPENMP) UNLOCK_COMMAND(&alloc_lock); #endif + if (memory_overflowed) { +#if (defined(SMP) || defined(USE_LOCKING)) && !defined(USE_OPENMP) + LOCK_COMMAND(&alloc_lock); +#endif + do { + RMB; +#if defined(USE_OPENMP) + if (!newmemory[position-NUM_BUFFERS].used) { + blas_lock(&newmemory[position-NUM_BUFFERS].lock); +#endif + if (!newmemory[position-NUM_BUFFERS].used) goto allocation2; + +#if defined(USE_OPENMP) + blas_unlock(&newmemory[position-NUM_BUFFERS].lock); + } +#endif + position ++; + + } while (position < 512+NUM_BUFFERS); +#if (defined(SMP) || defined(USE_LOCKING)) && !defined(USE_OPENMP) + UNLOCK_COMMAND(&alloc_lock); +#endif +} goto error; allocation : @@ -2878,8 +2984,97 @@ void *blas_memory_alloc(int procpos){ return (void *)memory[position].addr; error: - printf("BLAS : Program is Terminated. Because you tried to allocate too many memory regions.\n"); + if (memory_overflowed) goto terminate; + fprintf(stderr,"OpenBLAS warning: precompiled NUM_THREADS exceeded, adding auxiliary array for thread metadata.\n"); + memory_overflowed=1; + new_release_info = (struct release_t*) malloc(512*sizeof(struct release_t)); + newmemory = (struct newmemstruct*) malloc(512*sizeof(struct newmemstruct)); + for (i = 0; i < 512; i++) { + newmemory[i].addr = (void *)0; +#if defined(WHEREAMI) && !defined(USE_OPENMP) + newmemory[i].pos = -1; +#endif + newmemory[i].used = 0; + newmemory[i].lock = 0; +} + newmemory[position-NUM_BUFFERS].used = 1; + +allocation2: + newmemory[position-NUM_BUFFERS].used = 1; +#if (defined(SMP) || defined(USE_LOCKING)) && !defined(USE_OPENMP) + UNLOCK_COMMAND(&alloc_lock); +#else + blas_unlock(&newmemory[position-NUM_BUFFERS].lock); +#endif + do { +#ifdef DEBUG + printf("Allocation Start : %lx\n", base_address); +#endif + map_address = (void *)-1; + + func = &memoryalloc[0]; + + while ((func != NULL) && (map_address == (void *) -1)) { + + map_address = (*func)((void *)base_address); + +#ifdef ALLOC_DEVICEDRIVER + if ((*func == alloc_devicedirver) && (map_address == (void *)-1)) { + fprintf(stderr, "OpenBLAS Warning ... Physically contiguous allocation was failed.\n"); + } +#endif + +#ifdef ALLOC_HUGETLBFILE + if ((*func == alloc_hugetlbfile) && (map_address == (void *)-1)) { +#ifndef OS_WINDOWS + fprintf(stderr, "OpenBLAS Warning ... 
HugeTLB(File) allocation was failed.\n"); +#endif + } +#endif + +#if (defined ALLOC_SHM) && (defined OS_LINUX || defined OS_AIX || defined __sun__ || defined OS_WINDOWS) + if ((*func == alloc_hugetlb) && (map_address != (void *)-1)) hugetlb_allocated = 1; +#endif + + func ++; + } + +#ifdef DEBUG + printf(" Success -> %08lx\n", map_address); +#endif + if (((BLASLONG) map_address) == -1) base_address = 0UL; + + if (base_address) base_address += BUFFER_SIZE + FIXED_PAGESIZE; + + } while ((BLASLONG)map_address == -1); + +#if (defined(SMP) || defined(USE_LOCKING)) && !defined(USE_OPENMP) + LOCK_COMMAND(&alloc_lock); +#endif + newmemory[position-NUM_BUFFERS].addr = map_address; +#if (defined(SMP) || defined(USE_LOCKING)) && !defined(USE_OPENMP) + UNLOCK_COMMAND(&alloc_lock); +#endif + +#ifdef DEBUG + printf(" Mapping Succeeded. %p(%d)\n", (void *)newmemory[position-NUM_BUFFERS].addr, position); +#endif + +#if defined(WHEREAMI) && !defined(USE_OPENMP) + + if (newmemory[position-NUM_BUFFERS].pos == -1) newmemory[position-NUM_BUFFERS].pos = mypos; + +#endif + return (void *)newmemory[position-NUM_BUFFERS].addr; + +terminate: + printf("OpenBLAS : Program is Terminated. Because you tried to allocate too many memory regions.\n"); + printf("This library was built to support a maximum of %d threads - either rebuild OpenBLAS\n", NUM_BUFFERS); + printf("with a larger NUM_THREADS value or set the environment variable OPENBLAS_NUM_THREADS to\n"); + printf("a sufficiently small number. This error typically occurs when the software that relies on\n"); + printf("OpenBLAS calls BLAS functions from many threads in parallel, or when your computer has more\n"); + printf("cpu cores than what OpenBLAS was configured to handle.\n"); return NULL; } @@ -2898,13 +3093,28 @@ void blas_memory_free(void *free_area){ while ((position < NUM_BUFFERS) && (memory[position].addr != free_area)) position++; - if (position >= NUM_BUFFERS) goto error; + if (position >= NUM_BUFFERS && !memory_overflowed) goto error; #ifdef DEBUG if (memory[position].addr != free_area) goto error; printf(" Position : %d\n", position); #endif + if (unlikely(memory_overflowed && position >= NUM_BUFFERS)) { + while ((position < NUM_BUFFERS+512) && (newmemory[position-NUM_BUFFERS].addr != free_area)) + position++; + // arm: ensure all writes are finished before other thread takes this memory + WMB; + newmemory[position-NUM_BUFFERS].used = 0; +#if (defined(SMP) || defined(USE_LOCKING)) && !defined(USE_OPENMP) + UNLOCK_COMMAND(&alloc_lock); +#endif + +#ifdef DEBUG + printf("Unmap from overflow area succeeded.\n\n"); +#endif + return; +} else { // arm: ensure all writes are finished before other thread takes this memory WMB; @@ -2918,7 +3128,7 @@ void blas_memory_free(void *free_area){ #endif return; - +} error: printf("BLAS : Bad memory unallocation! 
: %4d %p\n", position, free_area); @@ -2953,7 +3163,10 @@ void blas_shutdown(void){ LOCK_COMMAND(&alloc_lock); for (pos = 0; pos < release_pos; pos ++) { + if (likely(pos < NUM_BUFFERS)) release_info[pos].func(&release_info[pos]); + else + new_release_info[pos-NUM_BUFFERS].func(&new_release_info[pos-NUM_BUFFERS]); } #ifdef SEEK_ADDRESS @@ -2970,6 +3183,15 @@ void blas_shutdown(void){ #endif memory[pos].lock = 0; } + if (memory_overflowed) + for (pos = 0; pos < 512; pos ++){ + newmemory[pos].addr = (void *)0; + newmemory[pos].used = 0; +#if defined(WHEREAMI) && !defined(USE_OPENMP) + newmemory[pos].pos = -1; +#endif + newmemory[pos].lock = 0; + } UNLOCK_COMMAND(&alloc_lock); diff --git a/driver/others/parameter.c b/driver/others/parameter.c index 36da13369..791e5dc27 100644 --- a/driver/others/parameter.c +++ b/driver/others/parameter.c @@ -524,6 +524,9 @@ void blas_set_parameter(void){ xgemm_p = ((xgemm_p + XGEMM_UNROLL_M - 1)/XGEMM_UNROLL_M) * XGEMM_UNROLL_M; #endif +#ifdef BUILD_BFLOAT16 + sbgemm_r = (((BUFFER_SIZE - ((SBGEMM_P * SBGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SBGEMM_Q * 4)) - 15) & ~15; +#endif sgemm_r = (((BUFFER_SIZE - ((SGEMM_P * SGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SGEMM_Q * 4)) - 15) & ~15; dgemm_r = (((BUFFER_SIZE - ((DGEMM_P * DGEMM_Q * 8 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (DGEMM_Q * 8)) - 15) & ~15; cgemm_r = (((BUFFER_SIZE - ((CGEMM_P * CGEMM_Q * 8 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (CGEMM_Q * 8)) - 15) & ~15; @@ -629,7 +632,9 @@ void blas_set_parameter(void){ xgemm_p = 16 * (size + 1); #endif +#ifdef BUILD_BFLOAT16 sbgemm_r = (((BUFFER_SIZE - ((SBGEMM_P * SBGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SBGEMM_Q * 4)) - 15) & ~15; +#endif sgemm_r = (((BUFFER_SIZE - ((SGEMM_P * SGEMM_Q * 4 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (SGEMM_Q * 4)) - 15) & ~15; dgemm_r = (((BUFFER_SIZE - ((DGEMM_P * DGEMM_Q * 8 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (DGEMM_Q * 8)) - 15) & ~15; cgemm_r = (((BUFFER_SIZE - ((CGEMM_P * CGEMM_Q * 8 + GEMM_OFFSET_A + GEMM_ALIGN) & ~GEMM_ALIGN)) / (CGEMM_Q * 8)) - 15) & ~15; diff --git a/getarch.c b/getarch.c index 3bc8a0c3d..094feaadd 100644 --- a/getarch.c +++ b/getarch.c @@ -142,6 +142,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. /* #define FORCE_SICORTEX */ /* #define FORCE_LOONGSON3R3 */ /* #define FORCE_LOONGSON3R4 */ +/* #define FORCE_LOONGSON3R5 */ /* #define FORCE_I6400 */ /* #define FORCE_P6600 */ /* #define FORCE_P5600 */ @@ -312,6 +313,16 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define FORCE #define FORCE_INTEL #define ARCHITECTURE "X86" +#ifdef NO_AVX +#define SUBARCHITECTURE "NEHALEM" +#define ARCHCONFIG "-DNEHALEM " \ + "-DL1_DATA_SIZE=32768 -DL1_DATA_LINESIZE=64 " \ + "-DL2_SIZE=262144 -DL2_LINESIZE=64 " \ + "-DDTB_DEFAULT_ENTRIES=64 -DDTB_SIZE=4096 " \ + "-DHAVE_CMOV -DHAVE_MMX -DHAVE_SSE -DHAVE_SSE2 -DHAVE_SSE3 -DHAVE_SSSE3 -DHAVE_SSE4_1 -DHAVE_SSE4_2" +#define LIBNAME "nehalem" +#define CORENAME "NEHALEM" +#else #define SUBARCHITECTURE "SANDYBRIDGE" #define ARCHCONFIG "-DSANDYBRIDGE " \ "-DL1_DATA_SIZE=32768 -DL1_DATA_LINESIZE=64 " \ @@ -321,12 +332,23 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
 #define LIBNAME "sandybridge"
 #define CORENAME "SANDYBRIDGE"
 #endif
+#endif
 
 #ifdef FORCE_HASWELL
 #define FORCE
 #define FORCE_INTEL
 #define ARCHITECTURE "X86"
 #ifdef NO_AVX2
+#ifdef NO_AVX
+#define SUBARCHITECTURE "NEHALEM"
+#define ARCHCONFIG "-DNEHALEM " \
+    "-DL1_DATA_SIZE=32768 -DL1_DATA_LINESIZE=64 " \
+    "-DL2_SIZE=262144 -DL2_LINESIZE=64 " \
+    "-DDTB_DEFAULT_ENTRIES=64 -DDTB_SIZE=4096 " \
+    "-DHAVE_CMOV -DHAVE_MMX -DHAVE_SSE -DHAVE_SSE2 -DHAVE_SSE3 -DHAVE_SSSE3 -DHAVE_SSE4_1 -DHAVE_SSE4_2"
+#define LIBNAME "nehalem"
+#define CORENAME "NEHALEM"
+#else
 #define SUBARCHITECTURE "SANDYBRIDGE"
 #define ARCHCONFIG "-DSANDYBRIDGE " \
     "-DL1_DATA_SIZE=32768 -DL1_DATA_LINESIZE=64 " \
@@ -335,6 +357,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
     "-DHAVE_CMOV -DHAVE_MMX -DHAVE_SSE -DHAVE_SSE2 -DHAVE_SSE3 -DHAVE_SSSE3 -DHAVE_SSE4_1 -DHAVE_SSE4_2 -DHAVE_AVX"
 #define LIBNAME "sandybridge"
 #define CORENAME "SANDYBRIDGE"
+#endif
 #else
 #define SUBARCHITECTURE "HASWELL"
 #define ARCHCONFIG "-DHASWELL " \
@@ -349,10 +372,31 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #endif
 
 #ifdef FORCE_SKYLAKEX
-#ifdef NO_AVX512
 #define FORCE
 #define FORCE_INTEL
 #define ARCHITECTURE "X86"
+#ifdef NO_AVX512
+#ifdef NO_AVX2
+#ifdef NO_AVX
+#define SUBARCHITECTURE "NEHALEM"
+#define ARCHCONFIG "-DNEHALEM " \
+    "-DL1_DATA_SIZE=32768 -DL1_DATA_LINESIZE=64 " \
+    "-DL2_SIZE=262144 -DL2_LINESIZE=64 " \
+    "-DDTB_DEFAULT_ENTRIES=64 -DDTB_SIZE=4096 " \
+    "-DHAVE_CMOV -DHAVE_MMX -DHAVE_SSE -DHAVE_SSE2 -DHAVE_SSE3 -DHAVE_SSSE3 -DHAVE_SSE4_1 -DHAVE_SSE4_2"
+#define LIBNAME "nehalem"
+#define CORENAME "NEHALEM"
+#else
+#define SUBARCHITECTURE "SANDYBRIDGE"
+#define ARCHCONFIG "-DSANDYBRIDGE " \
+    "-DL1_DATA_SIZE=32768 -DL1_DATA_LINESIZE=64 " \
+    "-DL2_SIZE=262144 -DL2_LINESIZE=64 " \
+    "-DDTB_DEFAULT_ENTRIES=64 -DDTB_SIZE=4096 " \
+    "-DHAVE_CMOV -DHAVE_MMX -DHAVE_SSE -DHAVE_SSE2 -DHAVE_SSE3 -DHAVE_SSSE3 -DHAVE_SSE4_1 -DHAVE_SSE4_2 -DHAVE_AVX"
+#define LIBNAME "sandybridge"
+#define CORENAME "SANDYBRIDGE"
+#endif
+#else
 #define SUBARCHITECTURE "HASWELL"
 #define ARCHCONFIG "-DHASWELL " \
     "-DL1_DATA_SIZE=32768 -DL1_DATA_LINESIZE=64 " \
@@ -362,10 +406,8 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
     "-DHAVE_AVX2 -DHAVE_FMA3 -DFMA3"
 #define LIBNAME "haswell"
 #define CORENAME "HASWELL"
+#endif
 #else
-#define FORCE
-#define FORCE_INTEL
-#define ARCHITECTURE "X86"
 #define SUBARCHITECTURE "SKYLAKEX"
 #define ARCHCONFIG "-DSKYLAKEX " \
     "-DL1_DATA_SIZE=32768 -DL1_DATA_LINESIZE=64 " \
@@ -379,10 +421,31 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #endif
 
 #ifdef FORCE_COOPERLAKE
-#ifdef NO_AVX512
 #define FORCE
 #define FORCE_INTEL
 #define ARCHITECTURE "X86"
+#ifdef NO_AVX512
+#ifdef NO_AVX2
+#ifdef NO_AVX
+#define SUBARCHITECTURE "NEHALEM"
+#define ARCHCONFIG "-DNEHALEM " \
+    "-DL1_DATA_SIZE=32768 -DL1_DATA_LINESIZE=64 " \
+    "-DL2_SIZE=262144 -DL2_LINESIZE=64 " \
+    "-DDTB_DEFAULT_ENTRIES=64 -DDTB_SIZE=4096 " \
+    "-DHAVE_CMOV -DHAVE_MMX -DHAVE_SSE -DHAVE_SSE2 -DHAVE_SSE3 -DHAVE_SSSE3 -DHAVE_SSE4_1 -DHAVE_SSE4_2"
+#define LIBNAME "nehalem"
+#define CORENAME "NEHALEM"
+#else
+#define SUBARCHITECTURE "SANDYBRIDGE"
+#define ARCHCONFIG "-DSANDYBRIDGE " \
+    "-DL1_DATA_SIZE=32768 -DL1_DATA_LINESIZE=64 " \
+    "-DL2_SIZE=262144 -DL2_LINESIZE=64 " \
+    "-DDTB_DEFAULT_ENTRIES=64 -DDTB_SIZE=4096 " \
+    "-DHAVE_CMOV -DHAVE_MMX -DHAVE_SSE -DHAVE_SSE2 -DHAVE_SSE3 -DHAVE_SSSE3 -DHAVE_SSE4_1 -DHAVE_SSE4_2 -DHAVE_AVX"
+#define LIBNAME "sandybridge"
+#define CORENAME "SANDYBRIDGE"
+#endif
+#else
 #define SUBARCHITECTURE "HASWELL"
 #define ARCHCONFIG "-DHASWELL " \
     "-DL1_DATA_SIZE=32768 -DL1_DATA_LINESIZE=64 " \
@@ -392,10 +455,8 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
     "-DHAVE_AVX2 -DHAVE_FMA3 -DFMA3"
 #define LIBNAME "haswell"
 #define CORENAME "HASWELL"
+#endif
 #else
-#define FORCE
-#define FORCE_INTEL
-#define ARCHITECTURE "X86"
 #define SUBARCHITECTURE "COOPERLAKE"
 #define ARCHCONFIG "-DCOOPERLAKE " \
     "-DL1_DATA_SIZE=32768 -DL1_DATA_LINESIZE=64 " \
@@ -563,6 +624,16 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #define FORCE_INTEL
 #define ARCHITECTURE "X86"
 #ifdef NO_AVX2
+#ifdef NO_AVX
+#define SUBARCHITECTURE "NEHALEM"
+#define ARCHCONFIG "-DNEHALEM " \
+    "-DL1_DATA_SIZE=32768 -DL1_DATA_LINESIZE=64 " \
+    "-DL2_SIZE=262144 -DL2_LINESIZE=64 " \
+    "-DDTB_DEFAULT_ENTRIES=64 -DDTB_SIZE=4096 " \
+    "-DHAVE_CMOV -DHAVE_MMX -DHAVE_SSE -DHAVE_SSE2 -DHAVE_SSE3 -DHAVE_SSSE3 -DHAVE_SSE4_1 -DHAVE_SSE4_2"
+#define LIBNAME "nehalem"
+#define CORENAME "NEHALEM"
+#else
 #define SUBARCHITECTURE "SANDYBRIDGE"
 #define ARCHCONFIG "-DSANDYBRIDGE " \
     "-DL1_DATA_SIZE=32768 -DL1_DATA_LINESIZE=64 " \
@@ -571,6 +642,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
     "-DHAVE_CMOV -DHAVE_MMX -DHAVE_SSE -DHAVE_SSE2 -DHAVE_SSE3 -DHAVE_SSSE3 -DHAVE_SSE4_1 -DHAVE_SSE4_2 -DHAVE_AVX"
 #define LIBNAME "sandybridge"
 #define CORENAME "SANDYBRIDGE"
+#endif
 #else
 #define SUBARCHITECTURE "ZEN"
 #define ARCHCONFIG "-DZEN " \
@@ -842,6 +914,20 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #else
 #endif
 
+#ifdef FORCE_LOONGSON3R5
+#define FORCE
+#define ARCHITECTURE "LOONGARCH"
+#define SUBARCHITECTURE "LOONGSON3R5"
+#define SUBDIRNAME "loongarch64"
+#define ARCHCONFIG "-DLOONGSON3R5 " \
+    "-DL1_DATA_SIZE=65536 -DL1_DATA_LINESIZE=64 " \
+    "-DL2_SIZE=1048576 -DL2_LINESIZE=64 " \
+    "-DDTB_DEFAULT_ENTRIES=64 -DDTB_SIZE=4096 -DL2_ASSOCIATIVE=16 "
+#define LIBNAME "loongson3r5"
+#define CORENAME "LOONGSON3R5"
+#else
+#endif
+
 #ifdef FORCE_I6400
 #define FORCE
 #define ARCHITECTURE "MIPS"
@@ -1388,6 +1474,11 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #define OPENBLAS_SUPPORTED
 #endif
 
+#ifdef __loongarch64
+#include "cpuid_loongarch64.c"
+#define OPENBLAS_SUPPORTED
+#endif
+
 #ifdef __riscv
 #include "cpuid_riscv64.c"
 #define OPENBLAS_SUPPORTED
@@ -1463,7 +1554,7 @@ int main(int argc, char *argv[]){
 #ifdef FORCE
   printf("CORE=%s\n", CORENAME);
 #else
-#if defined(INTEL_AMD) || defined(POWER) || defined(__mips__) || defined(__arm__) || defined(__aarch64__) || defined(ZARCH) || defined(sparc)
+#if defined(INTEL_AMD) || defined(POWER) || defined(__mips__) || defined(__arm__) || defined(__aarch64__) || defined(ZARCH) || defined(sparc) || defined(__loongarch__)
   printf("CORE=%s\n", get_corename());
 #endif
 #endif
@@ -1611,7 +1702,7 @@ printf("ELF_VERSION=2\n");
 #ifdef FORCE
   printf("#define CHAR_CORENAME \"%s\"\n", CORENAME);
 #else
-#if defined(INTEL_AMD) || defined(POWER) || defined(__mips__) || defined(__arm__) || defined(__aarch64__) || defined(ZARCH) || defined(sparc)
+#if defined(INTEL_AMD) || defined(POWER) || defined(__mips__) || defined(__arm__) || defined(__aarch64__) || defined(ZARCH) || defined(sparc) || defined(__loongarch__)
   printf("#define CHAR_CORENAME \"%s\"\n", get_corename());
 #endif
 #endif
diff --git a/interface/CMakeLists.txt b/interface/CMakeLists.txt
index 5346ecadd..ccb5fce3f 100644
--- a/interface/CMakeLists.txt
+++ b/interface/CMakeLists.txt
@@ -82,6 +82,7 @@ foreach (CBLAS_FLAG ${CBLAS_FLAGS})
   GenerateNamedObjects("${BLAS3_SOURCES}" "" "" ${CBLAS_FLAG} "" "" false ${DISABLE_COMPLEX})
   GenerateNamedObjects("${BLAS3_MANGLED_SOURCES}" "" "" ${CBLAS_FLAG} "" "" false ${MANGLE_COMPLEX})
 
+  GenerateNamedObjects("xerbla.c" "" "xerbla" ${CBLAS_FLAG} "" "" true)
   #sdsdot, dsdot
   if (BUILD_SINGLE OR BUILD_DOUBLE)
     GenerateNamedObjects("sdsdot.c" "" "sdsdot" ${CBLAS_FLAG} "" "" true "SINGLE")
@@ -104,6 +105,15 @@ endif ()
   GenerateNamedObjects("imax.c" "USE_ABS;USE_MIN" "i*amin" ${CBLAS_FLAG})
   GenerateNamedObjects("imax.c" "USE_MIN" "i*min" ${CBLAS_FLAG})
 
+if (BUILD_BFLOAT16)
+  GenerateNamedObjects("bf16dot.c" "" "sbdot" ${CBLAS_FLAG} "" "" true "BFLOAT16")
+  GenerateNamedObjects("gemm.c" "" "sbgemm" ${CBLAS_FLAG} "" "" true "BFLOAT16")
+  GenerateNamedObjects("sbgemv.c" "" "sbgemv" ${CBLAS_FLAG} "" "" true "BFLOAT16")
+  GenerateNamedObjects("tobf16.c" "SINGLE_PREC" "sbstobf16" ${CBLAS_FLAG} "" "" true "BFLOAT16")
+  GenerateNamedObjects("tobf16.c" "DOUBLE_PREC" "sbdtobf16" ${CBLAS_FLAG} "" "" true "BFLOAT16")
+  GenerateNamedObjects("bf16to.c" "SINGLE_PREC" "sbf16tos" ${CBLAS_FLAG} "" "" true "BFLOAT16")
+  GenerateNamedObjects("bf16to.c" "DOUBLE_PREC" "dbf16tod" ${CBLAS_FLAG} "" "" true "BFLOAT16")
+endif ()
 
 # complex-specific sources
 foreach (float_type ${FLOAT_TYPES})
diff --git a/interface/gemm.c b/interface/gemm.c
index 10426fd8f..71cc77a1b 100644
--- a/interface/gemm.c
+++ b/interface/gemm.c
@@ -105,6 +105,55 @@ static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, IFLOAT *, IFLOAT *, B
 #endif
 };
 
+#if defined(SMALL_MATRIX_OPT) && !defined(GEMM3M) && !defined(XDOUBLE)
+#define USE_SMALL_MATRIX_OPT 1
+#else
+#define USE_SMALL_MATRIX_OPT 0
+#endif
+
+#if USE_SMALL_MATRIX_OPT
+#ifndef DYNAMIC_ARCH
+#define SMALL_KERNEL_ADDR(table, idx) ((void *)(table[idx]))
+#else
+#define SMALL_KERNEL_ADDR(table, idx) ((void *)(*(uintptr_t *)((char *)gotoblas + (size_t)(table[idx]))))
+#endif
+
+
+#ifndef COMPLEX
+static size_t gemm_small_kernel[] = {
+  GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, 0, 0,
+  GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, 0, 0,
+};
+
+
+static size_t gemm_small_kernel_b0[] = {
+  GEMM_SMALL_KERNEL_B0_NN, GEMM_SMALL_KERNEL_B0_TN, 0, 0,
+  GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, 0, 0,
+};
+
+#define GEMM_SMALL_KERNEL_B0(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, IFLOAT *, BLASLONG, FLOAT, IFLOAT *, BLASLONG, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(gemm_small_kernel_b0, (idx))
+#define GEMM_SMALL_KERNEL(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, IFLOAT *, BLASLONG, FLOAT, IFLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(gemm_small_kernel, (idx))
+#else
+
+static size_t zgemm_small_kernel[] = {
+  GEMM_SMALL_KERNEL_NN, GEMM_SMALL_KERNEL_TN, GEMM_SMALL_KERNEL_RN, GEMM_SMALL_KERNEL_CN,
+  GEMM_SMALL_KERNEL_NT, GEMM_SMALL_KERNEL_TT, GEMM_SMALL_KERNEL_RT, GEMM_SMALL_KERNEL_CT,
+  GEMM_SMALL_KERNEL_NR, GEMM_SMALL_KERNEL_TR, GEMM_SMALL_KERNEL_RR, GEMM_SMALL_KERNEL_CR,
+  GEMM_SMALL_KERNEL_NC, GEMM_SMALL_KERNEL_TC, GEMM_SMALL_KERNEL_RC, GEMM_SMALL_KERNEL_CC,
+};
+
+static size_t zgemm_small_kernel_b0[] = {
+  GEMM_SMALL_KERNEL_B0_NN, GEMM_SMALL_KERNEL_B0_TN, GEMM_SMALL_KERNEL_B0_RN, GEMM_SMALL_KERNEL_B0_CN,
+  GEMM_SMALL_KERNEL_B0_NT, GEMM_SMALL_KERNEL_B0_TT, GEMM_SMALL_KERNEL_B0_RT, GEMM_SMALL_KERNEL_B0_CT,
+  GEMM_SMALL_KERNEL_B0_NR, GEMM_SMALL_KERNEL_B0_TR, GEMM_SMALL_KERNEL_B0_RR, GEMM_SMALL_KERNEL_B0_CR,
+  GEMM_SMALL_KERNEL_B0_NC, GEMM_SMALL_KERNEL_B0_TC, GEMM_SMALL_KERNEL_B0_RC, GEMM_SMALL_KERNEL_B0_CC,
+};
+
+#define ZGEMM_SMALL_KERNEL(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(zgemm_small_kernel, (idx))
+#define ZGEMM_SMALL_KERNEL_B0(idx) (int (*)(BLASLONG, BLASLONG, BLASLONG, FLOAT *, BLASLONG, FLOAT , FLOAT, FLOAT *, BLASLONG, FLOAT *, BLASLONG)) SMALL_KERNEL_ADDR(zgemm_small_kernel_b0, (idx))
+#endif
+#endif
+
 #ifndef CBLAS
 
 void NAME(char *TRANSA, char *TRANSB,
@@ -224,8 +273,8 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
            blasint m, blasint n, blasint k,
 #ifndef COMPLEX
            FLOAT alpha,
-           FLOAT *a, blasint lda,
-           FLOAT *b, blasint ldb,
+           IFLOAT *a, blasint lda,
+           IFLOAT *b, blasint ldb,
            FLOAT beta,
            FLOAT *c, blasint ldc) {
 #else
@@ -277,7 +326,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
 
   PRINT_DEBUG_CNAME;
 
-#if !defined(COMPLEX) && !defined(DOUBLE) && defined(USE_SGEMM_KERNEL_DIRECT)
+#if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && defined(USE_SGEMM_KERNEL_DIRECT)
 #ifdef DYNAMIC_ARCH
   if (support_avx512() )
 #endif
@@ -417,6 +466,28 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
 
   FUNCTION_PROFILE_START();
 
+#if USE_SMALL_MATRIX_OPT
+#if !defined(COMPLEX)
+  if(GEMM_SMALL_MATRIX_PERMIT(transa, transb, args.m, args.n, args.k, *(FLOAT *)(args.alpha), *(FLOAT *)(args.beta))){
+    if(*(FLOAT *)(args.beta) == 0.0){
+      (GEMM_SMALL_KERNEL_B0((transb << 2) | transa))(args.m, args.n, args.k, args.a, args.lda, *(FLOAT *)(args.alpha), args.b, args.ldb, args.c, args.ldc);
+    }else{
+      (GEMM_SMALL_KERNEL((transb << 2) | transa))(args.m, args.n, args.k, args.a, args.lda, *(FLOAT *)(args.alpha), args.b, args.ldb, *(FLOAT *)(args.beta), args.c, args.ldc);
+    }
+    return;
+  }
+#else
+  if(GEMM_SMALL_MATRIX_PERMIT(transa, transb, args.m, args.n, args.k, alpha[0], alpha[1], beta[0], beta[1])){
+    if(beta[0] == 0.0 && beta[1] == 0.0){
+      (ZGEMM_SMALL_KERNEL_B0((transb << 2) | transa))(args.m, args.n, args.k, args.a, args.lda, alpha[0], alpha[1], args.b, args.ldb, args.c, args.ldc);
+    }else{
+      (ZGEMM_SMALL_KERNEL((transb << 2) | transa))(args.m, args.n, args.k, args.a, args.lda, alpha[0], alpha[1], args.b, args.ldb, beta[0], beta[1], args.c, args.ldc);
+    }
+    return;
+  }
+#endif
+#endif
+
   buffer = (XFLOAT *)blas_memory_alloc(0);
 
   sa = (XFLOAT *)((BLASLONG)buffer +GEMM_OFFSET_A);
diff --git a/interface/zsyr.c b/interface/zsyr.c
index 71d4dbf29..54fb8a4e9 100644
--- a/interface/zsyr.c
+++ b/interface/zsyr.c
@@ -119,7 +119,7 @@ void NAME(char *UPLO, blasint *N, FLOAT *ALPHA,
 void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, int n, FLOAT alpha, FLOAT *x, int incx, FLOAT *a, int lda) {
 
   FLOAT *buffer;
-  int trans, uplo;
+  int uplo;
  blasint info;
   FLOAT * ALPHA = &alpha;
   FLOAT alpha_r = ALPHA[0];
@@ -130,7 +130,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo, int n, FLOAT alpha, FLO
 
   PRINT_DEBUG_CNAME;
 
-  trans = -1;
   uplo  = -1;
   info  =  0;
 
diff --git a/kernel/CMakeLists.txt b/kernel/CMakeLists.txt
index f0793bdef..9ffbd944f 100644
--- a/kernel/CMakeLists.txt
+++ b/kernel/CMakeLists.txt
@@ -91,6 +91,15 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
   GenerateNamedObjects("${KERNELDIR}/${DSDOTKERNEL}" "DSDOT" "d*dot_k" false "" "" false "SINGLE")
   GenerateNamedObjects("${KERNELDIR}/${DSDOTKERNEL}" "DSDOT" "dsdot_k" false "" "" false "SINGLE")
 
+  # sbdot
+  if (BUILD_BFLOAT16)
+    GenerateNamedObjects("${KERNELDIR}/${SBDOTKERNEL}" "SBDOT" "dot_k" false "" "" false "BFLOAT16")
+    GenerateNamedObjects("${KERNELDIR}/${BF16TOKERNEL}" "SINGLE" "f16tos_k" false "" "" false "BFLOAT16")
+    GenerateNamedObjects("${KERNELDIR}/${BF16TOKERNEL}" "DOUBLE" "bf16tod_k" false "" "" false "DOUBLE")
+    GenerateNamedObjects("${KERNELDIR}/${TOBF16KERNEL}" "SINGLE" "stobf16_k" false "" "" false "BFLOAT16")
+    GenerateNamedObjects("${KERNELDIR}/${TOBF16KERNEL}" "DOUBLE" "dtobf16_k" false "" "" false "BFLOAT16")
+  endif()
+
   if ((BUILD_COMPLEX OR BUILD_DOUBLE) AND NOT BUILD_SINGLE)
     GenerateNamedObjects("${KERNELDIR}/${SAMAXKERNEL}" "USE_ABS" "amax_k" false "" "" false "SINGLE")
     GenerateNamedObjects("${KERNELDIR}/${SAMINKERNEL}" "USE_ABS;USE_MIN" "amin_k" false "" "" false "SINGLE")
@@ -149,9 +158,6 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
   GenerateNamedObjects("generic/ger.c" "" "ger_k" false "" "" "" 3)
   foreach (float_type ${FLOAT_TYPES})
     string(SUBSTRING ${float_type} 0 1 float_char)
-    if (${float_type} STREQUAL "BFLOAT16")
-      set (float_char "SB")
-    endif ()
     if (${float_type} STREQUAL "COMPLEX" OR ${float_type} STREQUAL "ZCOMPLEX")
       GenerateNamedObjects("${KERNELDIR}/${${float_char}GERUKERNEL}" "" "geru_k" false "" "" false ${float_type})
       GenerateNamedObjects("${KERNELDIR}/${${float_char}GERCKERNEL}" "CONJ" "gerc_k" false "" "" false ${float_type})
@@ -185,6 +191,10 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
     GenerateNamedObjects("${KERNELDIR}/${SGEMVNKERNEL}" "" "gemv_n" false "" "" false "SINGLE")
     GenerateNamedObjects("${KERNELDIR}/${SGEMVTKERNEL}" "TRANS" "gemv_t" false "" "" false "SINGLE")
   endif ()
+  if (BUILD_BFLOAT16)
+    GenerateNamedObjects("${KERNELDIR}/${SBGEMVNKERNEL}" "" "gemv_n" false "" "" false "BFLOAT16")
+    GenerateNamedObjects("${KERNELDIR}/${SBGEMVTKERNEL}" "" "gemv_t" false "" "" false "BFLOAT16")
+  endif ()
 
   # Makefile.L3
   set(USE_TRMM false)
   string(TOUPPER ${TARGET_CORE} UC_TARGET_CORE)
@@ -209,15 +219,8 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
     GenerateNamedObjects("${KERNELDIR}/${SGEMMDIRECTPERFORMANT}" "" "gemm_direct_performant" false "" "" false SINGLE)
   endif()
 
-  foreach (float_type SINGLE DOUBLE BFLOAT16)
+  foreach (float_type SINGLE DOUBLE)
     string(SUBSTRING ${float_type} 0 1 float_char)
-    if (${float_type} STREQUAL "BFLOAT16")
-      if (NOT ${BUILD_BFLOAT16})
-        continue ()
-      else ()
-        set (float_char "SB")
-      endif ()
-    endif ()
     GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMMKERNEL}" "" "gemm_kernel" false "" "" false ${float_type})
   endforeach()
 
   if (BUILD_COMPLEX16 AND NOT BUILD_DOUBLE)
@@ -253,11 +256,24 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
     GenerateNamedObjects("${KERNELDIR}/${SGEMM_BETA}" "" "gemm_beta" false "" "" false "SINGLE")
   endif ()
 
+  if (BUILD_BFLOAT16)
+    if (SBGEMMINCOPY)
+      GenerateNamedObjects("${KERNELDIR}/${SBGEMMINCOPY}" "" "${SBGEMMINCOPYOBJ}" false "" "" true "BFLOAT16")
+    endif ()
+    if (SBGEMMITCOPY)
+      GenerateNamedObjects("${KERNELDIR}/${SBGEMMITCOPY}" "" "${SBGEMMITCOPYOBJ}" false "" "" true "BFLOAT16")
+    endif ()
+    if (SBGEMMONCOPY)
+      GenerateNamedObjects("${KERNELDIR}/${SBGEMMONCOPY}" "" "${SBGEMMONCOPYOBJ}" false "" "" true "BFLOAT16")
+    endif ()
+    if (SBGEMMOTCOPY)
+      GenerateNamedObjects("${KERNELDIR}/${SBGEMMOTCOPY}" "" "${SBGEMMOTCOPYOBJ}" false "" "" true "BFLOAT16")
+    endif ()
+    GenerateNamedObjects("${KERNELDIR}/${SBGEMMKERNEL}" "" "gemm_kernel" false "" "" false "BFLOAT16")
+    GenerateNamedObjects("${KERNELDIR}/${SBGEMM_BETA}" "" "gemm_beta" false "" "" false "BFLOAT16")
+  endif ()
   foreach (float_type ${FLOAT_TYPES})
     string(SUBSTRING ${float_type} 0 1 float_char)
-    if (${float_type} STREQUAL "BFLOAT16")
-      set (float_char "SB")
-    endif ()
     if (${float_char}GEMMINCOPY)
       GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMMINCOPY}" "${float_type}" "${${float_char}GEMMINCOPYOBJ}" false "" "" true ${float_type})
     endif ()
@@ -458,7 +474,155 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
     GenerateNamedObjects("${KERNELDIR}/${${float_char}TRSMKERNEL_RN}" "UPPER;RN;TRSMKERNEL" "trsm_kernel_RN" false "" "" false ${float_type})
     GenerateNamedObjects("${KERNELDIR}/${${float_char}TRSMKERNEL_RT}" "RT;TRSMKERNEL" "trsm_kernel_RT" false "" "" false ${float_type})
 
+    if (NOT DEFINED ${float_char}GEMM_SMALL_M_PERMIT)
+      if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C")
+        set(${float_char}GEMM_SMALL_M_PERMIT ../generic/zgemm_small_matrix_permit.c)
+      else ()
+        set(${float_char}GEMM_SMALL_M_PERMIT ../generic/gemm_small_matrix_permit.c)
+      endif ()
+    endif ()
+    if (NOT DEFINED ${float_char}GEMM_SMALL_K_NN)
+      if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C")
+        set(${float_char}GEMM_SMALL_K_NN ../generic/zgemm_small_matrix_kernel_nn.c)
+      else ()
+        set(${float_char}GEMM_SMALL_K_NN ../generic/gemm_small_matrix_kernel_nn.c)
+      endif ()
+    endif ()
+    if (NOT DEFINED ${float_char}GEMM_SMALL_K_NT)
+      if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C")
+        set(${float_char}GEMM_SMALL_K_NT ../generic/zgemm_small_matrix_kernel_nt.c)
+      else ()
+        set(${float_char}GEMM_SMALL_K_NT ../generic/gemm_small_matrix_kernel_nt.c)
+      endif ()
+    endif ()
+    if (NOT DEFINED ${float_char}GEMM_SMALL_K_TN)
+      if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C")
+        set(${float_char}GEMM_SMALL_K_TN ../generic/zgemm_small_matrix_kernel_tn.c)
+      else ()
+        set(${float_char}GEMM_SMALL_K_TN ../generic/gemm_small_matrix_kernel_tn.c)
+      endif ()
+    endif ()
+    if (NOT DEFINED ${float_char}GEMM_SMALL_K_TT)
+      if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C")
+        set(${float_char}GEMM_SMALL_K_TT ../generic/zgemm_small_matrix_kernel_tt.c)
+      else ()
+        set(${float_char}GEMM_SMALL_K_TT ../generic/gemm_small_matrix_kernel_tt.c)
+      endif ()
+    endif ()
+    if (NOT DEFINED ${float_char}GEMM_SMALL_K_B0_NN)
+      if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C")
+        set(${float_char}GEMM_SMALL_K_B0_NN ../generic/zgemm_small_matrix_kernel_nn.c)
+      else ()
+        set(${float_char}GEMM_SMALL_K_B0_NN ../generic/gemm_small_matrix_kernel_nn.c)
+      endif ()
+    endif ()
+    if (NOT DEFINED ${float_char}GEMM_SMALL_K_B0_NT)
+      if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C")
+        set(${float_char}GEMM_SMALL_K_B0_NT ../generic/zgemm_small_matrix_kernel_nt.c)
+      else ()
+        set(${float_char}GEMM_SMALL_K_B0_NT ../generic/gemm_small_matrix_kernel_nt.c)
+      endif ()
+    endif ()
+    if (NOT DEFINED ${float_char}GEMM_SMALL_K_B0_TN)
+      if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C")
+        set(${float_char}GEMM_SMALL_K_B0_TN ../generic/zgemm_small_matrix_kernel_tn.c)
+      else ()
+        set(${float_char}GEMM_SMALL_K_B0_TN ../generic/gemm_small_matrix_kernel_tn.c)
+      endif ()
+    endif ()
+    if (NOT DEFINED ${float_char}GEMM_SMALL_K_B0_TT)
+      if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C")
+        set(${float_char}GEMM_SMALL_K_B0_TT ../generic/zgemm_small_matrix_kernel_tt.c)
+      else ()
+        set(${float_char}GEMM_SMALL_K_B0_TT ../generic/gemm_small_matrix_kernel_tt.c)
+      endif ()
+    endif ()
+    if (SMALL_MATRIX_OPT)
+      GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_M_PERMIT}" "" "gemm_small_matrix_permit" false "" "" false ${float_type})
+      if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C")
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NN}" "NN" "gemm_small_kernel_nn" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NN}" "NR" "gemm_small_kernel_nr" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NN}" "RN" "gemm_small_kernel_rn" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NN}" "RR" "gemm_small_kernel_rr" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NT}" "NT" "gemm_small_kernel_nt" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NT}" "NC" "gemm_small_kernel_nc" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NT}" "RT" "gemm_small_kernel_rt" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NT}" "RC" "gemm_small_kernel_rc" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_TN}" "TN" "gemm_small_kernel_tn" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_TN}" "TR" "gemm_small_kernel_tr" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_TN}" "CN" "gemm_small_kernel_cn" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_TN}" "CR" "gemm_small_kernel_cr" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_TT}" "TT" "gemm_small_kernel_tt" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_TT}" "TC" "gemm_small_kernel_tc" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_TT}" "CT" "gemm_small_kernel_ct" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_TT}" "CC" "gemm_small_kernel_cc" false "" "" false ${float_type})
+
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NN}" "NN;B0" "gemm_small_kernel_b0_nn" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NN}" "NR;B0" "gemm_small_kernel_b0_nr" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NN}" "RN;B0" "gemm_small_kernel_b0_rn" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NN}" "RR;B0" "gemm_small_kernel_b0_rr" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NT}" "NT;B0" "gemm_small_kernel_b0_nt" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NT}" "NC;B0" "gemm_small_kernel_b0_nc" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NT}" "RT;B0" "gemm_small_kernel_b0_rt" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NT}" "RC;B0" "gemm_small_kernel_b0_rc" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TN}" "TN;B0" "gemm_small_kernel_b0_tn" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TN}" "TR;B0" "gemm_small_kernel_b0_tr" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TN}" "CN;B0" "gemm_small_kernel_b0_cn" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TN}" "CR;B0" "gemm_small_kernel_b0_cr" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TT}" "TT;B0" "gemm_small_kernel_b0_tt" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TT}" "TC;B0" "gemm_small_kernel_b0_tc" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TT}" "CT;B0" "gemm_small_kernel_b0_ct" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TT}" "CC;B0" "gemm_small_kernel_b0_cc" false "" "" false ${float_type})
+
+      else ()
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NN}" "" "gemm_small_kernel_nn" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_NT}" "" "gemm_small_kernel_nt" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_TN}" "" "gemm_small_kernel_tn" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_TT}" "" "gemm_small_kernel_tt" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NN}" "B0" "gemm_small_kernel_b0_nn" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_NT}" "B0" "gemm_small_kernel_b0_nt" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TN}" "B0" "gemm_small_kernel_b0_tn" false "" "" false ${float_type})
+        GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMM_SMALL_K_B0_TT}" "B0" "gemm_small_kernel_b0_tt" false "" "" false ${float_type})
+      endif ()
+      if (BUILD_BFLOAT16)
+        if (NOT DEFINED SBGEMM_SMALL_M_PERMIT)
+          set(SBGEMM_SMALL_M_PERMIT ../generic/gemm_small_matrix_permit.c)
+        endif ()
+        if (NOT DEFINED SBGEMM_SMALL_K_NN)
+          set(SBGEMM_SMALL_K_NN ../generic/gemm_small_matrix_kernel_nn.c)
+        endif ()
+        if (NOT DEFINED SBGEMM_SMALL_K_NT)
+          set(SBGEMM_SMALL_K_NT ../generic/gemm_small_matrix_kernel_nt.c)
+        endif ()
+        if (NOT DEFINED SBGEMM_SMALL_K_TN)
+          set(SBGEMM_SMALL_K_TN ../generic/gemm_small_matrix_kernel_tn.c)
+        endif ()
+        if (NOT DEFINED SBGEMM_SMALL_K_TT)
+          set(SBGEMM_SMALL_K_TT ../generic/gemm_small_matrix_kernel_tt.c)
+        endif ()
+        if (NOT DEFINED SBGEMM_SMALL_K_B0_NN)
+          set(SBGEMM_SMALL_K_B0_NN ../generic/gemm_small_matrix_kernel_nn.c)
+        endif ()
+        if (NOT DEFINED SBGEMM_SMALL_K_B0_NT)
+          set(SBGEMM_SMALL_K_B0_NT ../generic/gemm_small_matrix_kernel_nt.c)
+        endif ()
+        if (NOT DEFINED SBGEMM_SMALL_K_B0_TN)
+          set(SBGEMM_SMALL_K_B0_TN ../generic/gemm_small_matrix_kernel_tn.c)
+        endif ()
+        if (NOT DEFINED SBGEMM_SMALL_K_B0_TT)
+          set(SBGEMM_SMALL_K_B0_TT ../generic/gemm_small_matrix_kernel_tt.c)
+        endif ()
+        GenerateNamedObjects("${KERNELDIR}/${SBGEMM_SMALL_M_PERMIT}" "" "gemm_small_matrix_permit" false "" "" false "BFLOAT16")
+        GenerateNamedObjects("${KERNELDIR}/${SBGEMM_SMALL_K_NN}" "" "gemm_small_kernel_nn" false "" "" false "BFLOAT16")
+        GenerateNamedObjects("${KERNELDIR}/${SBGEMM_SMALL_K_NT}" "" "gemm_small_kernel_nt" false "" "" false "BFLOAT16")
+        GenerateNamedObjects("${KERNELDIR}/${SBGEMM_SMALL_K_TN}" "" "gemm_small_kernel_tn" false "" "" false "BFLOAT16")
+        GenerateNamedObjects("${KERNELDIR}/${SBGEMM_SMALL_K_TT}" "" "gemm_small_kernel_tt" false "" "" false "BFLOAT16")
+        GenerateNamedObjects("${KERNELDIR}/${SBGEMM_SMALL_K_B0_NN}" "B0" "gemm_small_kernel_b0_nn" false "" "" false "BFLOAT16")
+        GenerateNamedObjects("${KERNELDIR}/${SBGEMM_SMALL_K_B0_NT}" "B0" "gemm_small_kernel_b0_nt" false "" "" false "BFLOAT16")
+        GenerateNamedObjects("${KERNELDIR}/${SBGEMM_SMALL_K_B0_TN}" "B0" "gemm_small_kernel_b0_tn" false "" "" false "BFLOAT16")
+        GenerateNamedObjects("${KERNELDIR}/${SBGEMM_SMALL_K_B0_TT}" "B0" "gemm_small_kernel_b0_tt" false "" "" false "BFLOAT16")
+      endif ()
+    endif ()
 
     if (NOT DEFINED ${float_char}OMATCOPY_CN)
       if (${float_char} STREQUAL "Z" OR ${float_char} STREQUAL "C")
@@ -592,6 +756,7 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
     #geadd
     GenerateNamedObjects("${KERNELDIR}/${${float_char}GEADD_KERNEL}" "" "geadd_k" false "" "" false ${float_type})
   endforeach ()
+
   if (BUILD_DOUBLE AND NOT BUILD_SINGLE)
     GenerateNamedObjects("${KERNELDIR}/${STRSMKERNEL_LN}" "UPPER;LN;TRSMKERNEL" "trsm_kernel_LN" false "" "" false "SINGLE")
     GenerateNamedObjects("${KERNELDIR}/${STRSMKERNEL_LT}" "LT;TRSMKERNEL" "trsm_kernel_LT" false "" "" false "SINGLE")
@@ -730,22 +895,22 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
     GenerateNamedObjects("generic/trsm_ltcopy_${SGEMM_UNROLL_N}.c" "OUTER;LOWER" "trsm_oltncopy" false "" ${TSUFFIX} false "SINGLE")
 
     if (SGEMMINCOPY)
-     GenerateNamedObjects("${KERNELDIR}/${SGEMMINCOPY}" "SINGLE" "${SGEMMINCOPYOBJ}" false "" "" true "SINGLE")
+    GenerateNamedObjects("${KERNELDIR}/${SGEMMINCOPY}" "SINGLE" "${SGEMMINCOPYOBJ}" false "" "" true "SINGLE")
     endif ()
-   if (SGEMMITCOPY)
-     GenerateNamedObjects("${KERNELDIR}/${SGEMMITCOPY}" "SINGLE" "${SGEMMITCOPYOBJ}" false "" "" true "SINGLE")
-   endif ()
-   if (SGEMMONCOPY)
-     GenerateNamedObjects("${KERNELDIR}/${SGEMMONCOPY}" "SINGLE" "${SGEMMONCOPYOBJ}" false "" "" true "SINGLE")
-   endif ()
-   if (SGEMMOTCOPY)
-     GenerateNamedObjects("${KERNELDIR}/${SGEMMOTCOPY}" "SINGLE" "${SGEMMOTCOPYOBJ}" false "" "" true "SINGLE")
+    if (SGEMMITCOPY)
+    GenerateNamedObjects("${KERNELDIR}/${SGEMMITCOPY}" "SINGLE" "${SGEMMITCOPYOBJ}" false "" "" true "SINGLE")
+    endif ()
+    if (SGEMMONCOPY)
+    GenerateNamedObjects("${KERNELDIR}/${SGEMMONCOPY}" "SINGLE" "${SGEMMONCOPYOBJ}" false "" "" true "SINGLE")
+    endif ()
+    if (SGEMMOTCOPY)
+    GenerateNamedObjects("${KERNELDIR}/${SGEMMOTCOPY}" "SINGLE" "${SGEMMOTCOPYOBJ}" false "" "" true "SINGLE")
     endif ()
     GenerateNamedObjects("${KERNELDIR}/${SGEMVNKERNEL}" "" "gemv_n" false "" "" false "SINGLE")
     GenerateNamedObjects("${KERNELDIR}/${SGEMVTKERNEL}" "TRANS" "gemv_t" false "" "" false "SINGLE")
   endif ()
-
-  if (BUILD_COMPLEX16 AND NOT BUILD_DOUBLE)
+
+  if (BUILD_COMPLEX16 AND NOT BUILD_DOUBLE)
     GenerateNamedObjects("generic/neg_tcopy_${DGEMM_UNROLL_M}.c" "" "neg_tcopy" false "" ${TSUFFIX} false "DOUBLE")
     GenerateNamedObjects("generic/laswp_ncopy_${DGEMM_UNROLL_N}.c" "" "laswp_ncopy" false "" ${TSUFFIX} false "DOUBLE")
   endif ()
diff --git a/kernel/Makefile.L2 b/kernel/Makefile.L2
index 888a9b959..ac53c29c3 100644
--- a/kernel/Makefile.L2
+++ b/kernel/Makefile.L2
@@ -1,3 +1,10 @@
+FMAFLAG=
+ifndef OLDGCC
+ifdef HAVE_FMA3
+FMAFLAG = -mfma
+endif
+endif
+
 ### GEMV ###
 
 ifndef SGEMVNKERNEL
@@ -263,7 +270,7 @@ $(KDIR)dgemv_n$(TSUFFIX).$(SUFFIX) $(KDIR)dgemv_n$(TSUFFIX).$(PSUFFIX) : $(KER
 	$(CC) -c $(CFLAGS) -DDOUBLE -UCOMPLEX -UTRANS $< -o $@
 
 $(KDIR)dgemv_t$(TSUFFIX).$(SUFFIX) $(KDIR)dgemv_t$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(DGEMVTKERNEL) $(TOPDIR)/common.h $(GEMVDEP)
-	$(CC) -c $(CFLAGS) -DDOUBLE -UCOMPLEX -DTRANS $< -o $@
+	$(CC) -c $(CFLAGS) $(FMAFLAG) -DDOUBLE -UCOMPLEX -DTRANS $< -o $@
 endif
 
 $(KDIR)qgemv_n$(TSUFFIX).$(SUFFIX) $(KDIR)qgemv_n$(TSUFFIX).$(PSUFFIX) : $(KERNELDIR)/$(QGEMVNKERNEL)
diff --git a/kernel/Makefile.L3 b/kernel/Makefile.L3
index 2d9e3ec36..2d274d33b 100644
--- a/kernel/Makefile.L3
+++ b/kernel/Makefile.L3
@@ -447,6 +447,72 @@ XBLASOBJS += \
 
 endif
 
+###### BLAS small matrix optimization #####
+ifeq ($(SMALL_MATRIX_OPT), 1)
+
+ifeq ($(BUILD_BFLOAT16),1)
+SBBLASOBJS += \
+	sbgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) \
+	sbgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) sbgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) \
+	sbgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) sbgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) \
+	sbgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) sbgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) \
+	sbgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) sbgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX)
+endif
+
+SBLASOBJS += \
+	sgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) \
+	sgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) \
+	sgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) \
+	sgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) \
+	sgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) sgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX)
+
+DBLASOBJS += \
+	dgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) \
+	dgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) dgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) \
+	dgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) dgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) \
+	dgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) dgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) \
+	dgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) dgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX)
+
+CBLASOBJS += \
+	cgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) \
+	cgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) \
+	cgemm_small_kernel_nr$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_nc$(TSUFFIX).$(SUFFIX) \
+	cgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) \
+	cgemm_small_kernel_tr$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_tc$(TSUFFIX).$(SUFFIX) \
+	cgemm_small_kernel_rn$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_rt$(TSUFFIX).$(SUFFIX) \
+	cgemm_small_kernel_rr$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_rc$(TSUFFIX).$(SUFFIX) \
+	cgemm_small_kernel_cn$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_ct$(TSUFFIX).$(SUFFIX) \
+	cgemm_small_kernel_cr$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_cc$(TSUFFIX).$(SUFFIX) \
+	cgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) \
+	cgemm_small_kernel_b0_nr$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_b0_nc$(TSUFFIX).$(SUFFIX) \
+	cgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) \
+	cgemm_small_kernel_b0_tr$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_b0_tc$(TSUFFIX).$(SUFFIX) \
+	cgemm_small_kernel_b0_rn$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_b0_rt$(TSUFFIX).$(SUFFIX) \
+	cgemm_small_kernel_b0_rr$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_b0_rc$(TSUFFIX).$(SUFFIX) \
+	cgemm_small_kernel_b0_cn$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_b0_ct$(TSUFFIX).$(SUFFIX) \
+	cgemm_small_kernel_b0_cr$(TSUFFIX).$(SUFFIX) cgemm_small_kernel_b0_cc$(TSUFFIX).$(SUFFIX)
+
+ZBLASOBJS += \
+	zgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) \
+	zgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) \
+	zgemm_small_kernel_nr$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_nc$(TSUFFIX).$(SUFFIX) \
+	zgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) \
+	zgemm_small_kernel_tr$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_tc$(TSUFFIX).$(SUFFIX) \
+	zgemm_small_kernel_rn$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_rt$(TSUFFIX).$(SUFFIX) \
+	zgemm_small_kernel_rr$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_rc$(TSUFFIX).$(SUFFIX) \
+	zgemm_small_kernel_cn$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_ct$(TSUFFIX).$(SUFFIX) \
+	zgemm_small_kernel_cr$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_cc$(TSUFFIX).$(SUFFIX) \
+	zgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) \
+	zgemm_small_kernel_b0_nr$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_b0_nc$(TSUFFIX).$(SUFFIX) \
+	zgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) \
+	zgemm_small_kernel_b0_tr$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_b0_tc$(TSUFFIX).$(SUFFIX) \
+	zgemm_small_kernel_b0_rn$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_b0_rt$(TSUFFIX).$(SUFFIX) \
+	zgemm_small_kernel_b0_rr$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_b0_rc$(TSUFFIX).$(SUFFIX) \
+	zgemm_small_kernel_b0_cn$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_b0_ct$(TSUFFIX).$(SUFFIX) \
+	zgemm_small_kernel_b0_cr$(TSUFFIX).$(SUFFIX) zgemm_small_kernel_b0_cc$(TSUFFIX).$(SUFFIX)
+
+endif
+
 ###### BLAS extensions #####
 
 ifeq ($(BUILD_SINGLE),1)
@@ -4237,3 +4303,469 @@ endif
 
 $(KDIR)zgeadd_k$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEADD_K)
 	$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -UROWM $< -o $@
+
+
+###### BLAS small matrix optimization #####
+
+ifndef DGEMM_SMALL_M_PERMIT
+DGEMM_SMALL_M_PERMIT = ../generic/gemm_small_matrix_permit.c
+endif
+
+ifndef DGEMM_SMALL_K_NN
+DGEMM_SMALL_K_NN = ../generic/gemm_small_matrix_kernel_nn.c
+endif
+
+ifndef DGEMM_SMALL_K_NT
+DGEMM_SMALL_K_NT = ../generic/gemm_small_matrix_kernel_nt.c
+endif
+
+ifndef DGEMM_SMALL_K_TN
+DGEMM_SMALL_K_TN = ../generic/gemm_small_matrix_kernel_tn.c
+endif
+
+ifndef DGEMM_SMALL_K_TT
+DGEMM_SMALL_K_TT = ../generic/gemm_small_matrix_kernel_tt.c
+endif
+
+$(KDIR)dgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_M_PERMIT)
+	$(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@
+
+$(KDIR)dgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_NN)
+	$(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@
+
+$(KDIR)dgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_NT)
+	$(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@
+
+$(KDIR)dgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_TN)
+	$(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@
+
+$(KDIR)dgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_TT)
+	$(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX $< -o $@
+
+ifndef DGEMM_SMALL_K_B0_NN
+DGEMM_SMALL_K_B0_NN = ../generic/gemm_small_matrix_kernel_nn.c
+endif
+
+ifndef DGEMM_SMALL_K_B0_NT
+DGEMM_SMALL_K_B0_NT = ../generic/gemm_small_matrix_kernel_nt.c
+endif
+
+ifndef DGEMM_SMALL_K_B0_TN
+DGEMM_SMALL_K_B0_TN = ../generic/gemm_small_matrix_kernel_tn.c
+endif
+
+ifndef DGEMM_SMALL_K_B0_TT
+DGEMM_SMALL_K_B0_TT = ../generic/gemm_small_matrix_kernel_tt.c
+endif
+
+$(KDIR)dgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_B0_NN)
+	$(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX -DB0 $< -o $@
+
+$(KDIR)dgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_B0_NT)
+	$(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX -DB0 $< -o $@
+
+$(KDIR)dgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_B0_TN)
+	$(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX -DB0 $< -o $@
+
+$(KDIR)dgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DGEMM_SMALL_K_B0_TT)
+	$(CC) $(CFLAGS) -c -DDOUBLE -UCOMPLEX -DB0 $< -o $@
+
+ifndef SGEMM_SMALL_M_PERMIT
+SGEMM_SMALL_M_PERMIT = ../generic/gemm_small_matrix_permit.c
+endif
+
+ifndef SGEMM_SMALL_K_NN
+SGEMM_SMALL_K_NN = ../generic/gemm_small_matrix_kernel_nn.c
+endif
+
+ifndef SGEMM_SMALL_K_NT
+SGEMM_SMALL_K_NT = ../generic/gemm_small_matrix_kernel_nt.c
+endif
+
+ifndef SGEMM_SMALL_K_TN
+SGEMM_SMALL_K_TN = ../generic/gemm_small_matrix_kernel_tn.c
+endif
+
+ifndef SGEMM_SMALL_K_TT
+SGEMM_SMALL_K_TT = ../generic/gemm_small_matrix_kernel_tt.c
+endif
+
+$(KDIR)sgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_M_PERMIT)
+	$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
+
+$(KDIR)sgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_NN)
+	$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
+
+$(KDIR)sgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_NT)
+	$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
+
+$(KDIR)sgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_TN)
+	$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
+
+$(KDIR)sgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_TT)
+	$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
+
+ifndef SGEMM_SMALL_K_B0_NN
+SGEMM_SMALL_K_B0_NN = ../generic/gemm_small_matrix_kernel_nn.c
+endif
+
+ifndef SGEMM_SMALL_K_B0_NT
+SGEMM_SMALL_K_B0_NT = ../generic/gemm_small_matrix_kernel_nt.c
+endif
+
+ifndef SGEMM_SMALL_K_B0_TN
+SGEMM_SMALL_K_B0_TN = ../generic/gemm_small_matrix_kernel_tn.c
+endif
+
+ifndef SGEMM_SMALL_K_B0_TT
+SGEMM_SMALL_K_B0_TT = ../generic/gemm_small_matrix_kernel_tt.c
+endif
+
+$(KDIR)sgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_B0_NN)
+	$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DB0 $< -o $@
+
+$(KDIR)sgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_B0_NT)
+	$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DB0 $< -o $@
+
+$(KDIR)sgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_B0_TN)
+	$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DB0 $< -o $@
+
+$(KDIR)sgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMM_SMALL_K_B0_TT)
+	$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DB0 $< -o $@
+
+
+ifeq ($(BUILD_BFLOAT16), 1)
+ifndef SBGEMM_SMALL_M_PERMIT
+SBGEMM_SMALL_M_PERMIT = ../generic/gemm_small_matrix_permit.c
+endif
+
+ifndef SBGEMM_SMALL_K_NN
+SBGEMM_SMALL_K_NN = ../generic/gemm_small_matrix_kernel_nn.c
+endif
+
+ifndef SBGEMM_SMALL_K_NT
+SBGEMM_SMALL_K_NT = ../generic/gemm_small_matrix_kernel_nt.c
+endif
+
+ifndef SBGEMM_SMALL_K_TN
+SBGEMM_SMALL_K_TN = ../generic/gemm_small_matrix_kernel_tn.c
+endif
+
+ifndef SBGEMM_SMALL_K_TT
+SBGEMM_SMALL_K_TT = ../generic/gemm_small_matrix_kernel_tt.c
+endif
+
+$(KDIR)sbgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_SMALL_M_PERMIT)
+	$(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@
+
+$(KDIR)sbgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_SMALL_K_NN)
+	$(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@
+
+$(KDIR)sbgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_SMALL_K_NT)
+	$(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@
+
+$(KDIR)sbgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_SMALL_K_TN)
+	$(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@
+
+$(KDIR)sbgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_SMALL_K_TT)
+	$(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX $< -o $@
+
+ifndef SBGEMM_SMALL_K_B0_NN
+SBGEMM_SMALL_K_B0_NN = ../generic/gemm_small_matrix_kernel_nn.c
+endif
+
+ifndef SBGEMM_SMALL_K_B0_NT
+SBGEMM_SMALL_K_B0_NT = ../generic/gemm_small_matrix_kernel_nt.c
+endif
+
+ifndef SBGEMM_SMALL_K_B0_TN
+SBGEMM_SMALL_K_B0_TN = ../generic/gemm_small_matrix_kernel_tn.c
+endif
+
+ifndef SBGEMM_SMALL_K_B0_TT
+SBGEMM_SMALL_K_B0_TT = ../generic/gemm_small_matrix_kernel_tt.c
+endif
+
+$(KDIR)sbgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_SMALL_K_B0_NN)
+	$(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX -DB0 $< -o $@
+
+$(KDIR)sbgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_SMALL_K_B0_NT)
+	$(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX -DB0 $< -o $@
+
+$(KDIR)sbgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_SMALL_K_B0_TN)
+	$(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX -DB0 $< -o $@
+
+$(KDIR)sbgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SBGEMM_SMALL_K_B0_TT)
+	$(CC) $(CFLAGS) -c -DBFLOAT16 -UDOUBLE -UCOMPLEX -DB0 $< -o $@
+endif
+
+ifndef CGEMM_SMALL_M_PERMIT
+CGEMM_SMALL_M_PERMIT = ../generic/zgemm_small_matrix_permit.c
+endif
+
+ifndef CGEMM_SMALL_K_NN
+CGEMM_SMALL_K_NN = ../generic/zgemm_small_matrix_kernel_nn.c
+endif
+
+ifndef CGEMM_SMALL_K_NT
+CGEMM_SMALL_K_NT = ../generic/zgemm_small_matrix_kernel_nt.c
+endif
+
+ifndef CGEMM_SMALL_K_TN
+CGEMM_SMALL_K_TN = ../generic/zgemm_small_matrix_kernel_tn.c
+endif
+
+ifndef CGEMM_SMALL_K_TT
+CGEMM_SMALL_K_TT = ../generic/zgemm_small_matrix_kernel_tt.c
+endif
+
+$(KDIR)cgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_M_PERMIT)
+	$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX $< -o $@
+
+$(KDIR)cgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_NN)
+	$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DNN $< -o $@
+
+$(KDIR)cgemm_small_kernel_nr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_NN)
+	$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DNR $< -o $@
+
+$(KDIR)cgemm_small_kernel_rn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_NN)
+	$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DRN $< -o $@
+
+$(KDIR)cgemm_small_kernel_rr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_NN)
+	$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DRR $< -o $@
+
+$(KDIR)cgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_NT)
+	$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DNT $< -o $@
+
+$(KDIR)cgemm_small_kernel_nc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_NT)
+	$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DNC $< -o $@
+
+$(KDIR)cgemm_small_kernel_rt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_NT)
+	$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DRT $< -o $@
+
+$(KDIR)cgemm_small_kernel_rc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_NT)
+	$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DRC=RC $< -o $@
+
+$(KDIR)cgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_TN)
+	$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DTN $< -o $@
+
+$(KDIR)cgemm_small_kernel_tr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_TN)
+	$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DTR $< -o $@
+
+$(KDIR)cgemm_small_kernel_cn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_TN)
+	$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DCN $< -o $@
+
+$(KDIR)cgemm_small_kernel_cr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_TN)
+	$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DCR=CR $< -o $@
+
+$(KDIR)cgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_TT)
+	$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DTT $< -o $@
+
+$(KDIR)cgemm_small_kernel_tc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_TT)
+	$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DTC $< -o $@
+
+$(KDIR)cgemm_small_kernel_ct$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_TT)
+	$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DCT $< -o $@
+
+$(KDIR)cgemm_small_kernel_cc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_TT)
+	$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DCC $< -o $@
+
+ifndef CGEMM_SMALL_K_B0_NN
+CGEMM_SMALL_K_B0_NN = ../generic/zgemm_small_matrix_kernel_nn.c
+endif
+
+ifndef CGEMM_SMALL_K_B0_NT
+CGEMM_SMALL_K_B0_NT = ../generic/zgemm_small_matrix_kernel_nt.c
+endif
+
+ifndef CGEMM_SMALL_K_B0_TN
+CGEMM_SMALL_K_B0_TN = ../generic/zgemm_small_matrix_kernel_tn.c
+endif
+
+ifndef CGEMM_SMALL_K_B0_TT
+CGEMM_SMALL_K_B0_TT = ../generic/zgemm_small_matrix_kernel_tt.c
+endif
+
+$(KDIR)cgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_NN)
+	$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DNN -DB0 $< -o $@
+
+$(KDIR)cgemm_small_kernel_b0_nr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_NN)
+	$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DNR -DB0 $< -o $@
+
+$(KDIR)cgemm_small_kernel_b0_rn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_NN)
+	$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DRN -DB0 $< -o $@
+
+$(KDIR)cgemm_small_kernel_b0_rr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_NN)
+	$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DRR -DB0 $< -o $@
+
+$(KDIR)cgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_NT)
+	$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DNT -DB0 $< -o $@
+
+$(KDIR)cgemm_small_kernel_b0_nc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_NT)
+	$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DNC -DB0 $< -o $@
+
+$(KDIR)cgemm_small_kernel_b0_rt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_NT)
+	$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DRT -DB0 $< -o $@
+
+$(KDIR)cgemm_small_kernel_b0_rc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_NT)
+	$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DRC=RC -DB0 $< -o $@
+
+$(KDIR)cgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_TN)
+	$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DTN -DB0 $< -o $@
+
+$(KDIR)cgemm_small_kernel_b0_tr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_TN)
+	$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DTR -DB0 $< -o $@
+
+$(KDIR)cgemm_small_kernel_b0_cn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_TN)
+	$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DCN -DB0 $< -o $@
+
+$(KDIR)cgemm_small_kernel_b0_cr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_TN)
+	$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DCR=CR -DB0 $< -o $@
+
+$(KDIR)cgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_TT)
+	$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DTT -DB0 $< -o $@
+
+$(KDIR)cgemm_small_kernel_b0_tc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_TT)
+	$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DTC -DB0 $< -o $@
+
+$(KDIR)cgemm_small_kernel_b0_ct$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_TT)
+	$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DCT -DB0 $< -o $@
+
+$(KDIR)cgemm_small_kernel_b0_cc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(CGEMM_SMALL_K_B0_TT)
+	$(CC) $(CFLAGS) -c -UDOUBLE -DCOMPLEX -DCC -DB0 $< -o $@
+
+ifndef ZGEMM_SMALL_M_PERMIT
+ZGEMM_SMALL_M_PERMIT = ../generic/zgemm_small_matrix_permit.c
+endif
+
+ifndef ZGEMM_SMALL_K_NN
+ZGEMM_SMALL_K_NN = ../generic/zgemm_small_matrix_kernel_nn.c
+endif
+
+ifndef ZGEMM_SMALL_K_NT
+ZGEMM_SMALL_K_NT = ../generic/zgemm_small_matrix_kernel_nt.c
+endif
+
+ifndef ZGEMM_SMALL_K_TN
+ZGEMM_SMALL_K_TN = ../generic/zgemm_small_matrix_kernel_tn.c
+endif
+
+ifndef ZGEMM_SMALL_K_TT
+ZGEMM_SMALL_K_TT = ../generic/zgemm_small_matrix_kernel_tt.c
+endif
+
+$(KDIR)zgemm_small_matrix_permit$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_M_PERMIT)
+	$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX $< -o $@
+
+
+$(KDIR)zgemm_small_kernel_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_NN)
+	$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DNN $< -o $@
+
+$(KDIR)zgemm_small_kernel_nr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_NN)
+	$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DNR $< -o $@
+
+$(KDIR)zgemm_small_kernel_rn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_NN)
+	$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DRN $< -o $@
+
+$(KDIR)zgemm_small_kernel_rr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_NN)
+	$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DRR $< -o $@
+
+$(KDIR)zgemm_small_kernel_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_NT)
+	$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DNT $< -o $@
+
+$(KDIR)zgemm_small_kernel_nc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_NT)
+	$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DNC $< -o $@
+
+$(KDIR)zgemm_small_kernel_rt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_NT)
+	$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DRT $< -o $@
+
+$(KDIR)zgemm_small_kernel_rc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_NT)
+	$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DRC=RC $< -o $@
+
+$(KDIR)zgemm_small_kernel_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_TN)
+	$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DTN $< -o $@
+
+$(KDIR)zgemm_small_kernel_tr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_TN)
+	$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DTR $< -o $@
+
+$(KDIR)zgemm_small_kernel_cn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_TN)
+	$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DCN $< -o $@
+
+$(KDIR)zgemm_small_kernel_cr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_TN)
+	$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DCR=CR $< -o $@
+
+$(KDIR)zgemm_small_kernel_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_TT)
+	$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DTT $< -o $@
+
+$(KDIR)zgemm_small_kernel_tc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_TT)
+	$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DTC $< -o $@
+
+$(KDIR)zgemm_small_kernel_ct$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_TT)
+	$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DCT $< -o $@
+
+$(KDIR)zgemm_small_kernel_cc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_TT)
+	$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DCC $< -o $@
+
+ifndef ZGEMM_SMALL_K_B0_NN
+ZGEMM_SMALL_K_B0_NN = ../generic/zgemm_small_matrix_kernel_nn.c
+endif
+
+ifndef ZGEMM_SMALL_K_B0_NT
+ZGEMM_SMALL_K_B0_NT = ../generic/zgemm_small_matrix_kernel_nt.c
+endif
+
+ifndef ZGEMM_SMALL_K_B0_TN
+ZGEMM_SMALL_K_B0_TN = ../generic/zgemm_small_matrix_kernel_tn.c
+endif
+
+ifndef ZGEMM_SMALL_K_B0_TT
+ZGEMM_SMALL_K_B0_TT = ../generic/zgemm_small_matrix_kernel_tt.c
+endif
+
+$(KDIR)zgemm_small_kernel_b0_nn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_NN)
+	$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DNN -DB0 $< -o $@
+
+$(KDIR)zgemm_small_kernel_b0_nr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_NN)
+	$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DNR -DB0 $< -o $@
+
+$(KDIR)zgemm_small_kernel_b0_rn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_NN)
+	$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DRN -DB0 $< -o $@
+
+$(KDIR)zgemm_small_kernel_b0_rr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_NN)
+	$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DRR -DB0 $< -o $@
+
+$(KDIR)zgemm_small_kernel_b0_nt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_NT)
+	$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DNT -DB0 $< -o $@
+
+$(KDIR)zgemm_small_kernel_b0_nc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_NT)
+	$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DNC -DB0 $< -o $@
+
+$(KDIR)zgemm_small_kernel_b0_rt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_NT)
+	$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DRT -DB0 $< -o $@
+
+$(KDIR)zgemm_small_kernel_b0_rc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_NT)
+	$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DRC=RC -DB0 $< -o $@
+
+$(KDIR)zgemm_small_kernel_b0_tn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_TN)
+	$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DTN -DB0 $< -o $@
+
+$(KDIR)zgemm_small_kernel_b0_tr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_TN)
+	$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DTR -DB0 $< -o $@
+
+$(KDIR)zgemm_small_kernel_b0_cn$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_TN)
+	$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DCN -DB0 $< -o $@
+
+$(KDIR)zgemm_small_kernel_b0_cr$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_TN)
+	$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DCR=CR -DB0 $< -o $@
+
+$(KDIR)zgemm_small_kernel_b0_tt$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_TT)
+	$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DTT -DB0 $< -o $@
+
+$(KDIR)zgemm_small_kernel_b0_tc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_TT)
+	$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DTC -DB0 $< -o $@
+
+$(KDIR)zgemm_small_kernel_b0_ct$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_TT)
+	$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DCT -DB0 $< -o $@
+
+$(KDIR)zgemm_small_kernel_b0_cc$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(ZGEMM_SMALL_K_B0_TT)
+	$(CC) $(CFLAGS) -c -DDOUBLE -DCOMPLEX -DCC -DB0 $< -o $@
diff --git a/kernel/arm64/dgemm_tcopy_8.S b/kernel/arm64/dgemm_tcopy_8.S
index 9ab51ff57..7e5bf6080 100644
--- a/kernel/arm64/dgemm_tcopy_8.S
+++ b/kernel/arm64/dgemm_tcopy_8.S
@@ -50,11 +50,10 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #define B03	x16
 #define B04	x17
 
-#define I	x18
-#define J	x19
+#define I	x19
+#define J	x20
 
-#define TEMP1	x20
-#define TEMP2	x21
+#define TEMP1	x21
 
 #define A_PREFETCH	2560
 #define B_PREFETCH	256
diff --git a/kernel/arm64/dtrmm_kernel_8x4.S b/kernel/arm64/dtrmm_kernel_8x4.S
index 0ac5a5f24..3d953266c 100644
--- a/kernel/arm64/dtrmm_kernel_8x4.S
+++ b/kernel/arm64/dtrmm_kernel_8x4.S
@@ -49,9 +49,10 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #define pCRow3	x15
 #define pA	x16
 #define alpha	x17
-#define temp	x18
+//#define temp	x18
 #define tempOffset	x19
 #define tempK	x20
+#define temp	x21
 
 #define alpha0	d10
 #define alphaV0	v10.d[0]
diff --git a/kernel/arm64/sgemm_tcopy_16.S b/kernel/arm64/sgemm_tcopy_16.S
index 46198b3a2..431f1ae2a 100644
--- a/kernel/arm64/sgemm_tcopy_16.S
+++ b/kernel/arm64/sgemm_tcopy_16.S
@@ -30,7 +30,7 @@ All rights reserved.
 
 #define B00	x22
 
-#define I	x18
+#define I	x21
 #define J	x19
 
 #define TEMP1	x20
diff --git a/kernel/arm64/strmm_kernel_16x4.S b/kernel/arm64/strmm_kernel_16x4.S
index 985a0a9a6..a44326aeb 100644
--- a/kernel/arm64/strmm_kernel_16x4.S
+++ b/kernel/arm64/strmm_kernel_16x4.S
@@ -49,9 +49,10 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #define pCRow3	x15
 #define pA	x16
 #define alpha	w17
-#define temp	x18
+//#define temp	x18
 #define tempOffset	x19
 #define tempK	x20
+#define temp	x21
 
 #define alpha0	s10
 #define alphaV0	v10.s[0]
diff --git a/kernel/arm64/zgemm_kernel_4x4.S b/kernel/arm64/zgemm_kernel_4x4.S
index f8e877f3c..a65c4f581 100644
--- a/kernel/arm64/zgemm_kernel_4x4.S
+++ b/kernel/arm64/zgemm_kernel_4x4.S
@@ -48,8 +48,8 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #define pCRow2	x14
 #define pCRow3	x15
 #define pA	x16
-#define alphaR	x17
-#define alphaI	x18
+#define alphaR	x19
+#define alphaI	x20
 
 #define alpha0_R	d10
 #define alphaV0_R	v10.d[0]
diff --git a/kernel/arm64/ztrmm_kernel_4x4.S b/kernel/arm64/ztrmm_kernel_4x4.S
index 462acfe2b..cd053b896 100644
--- a/kernel/arm64/ztrmm_kernel_4x4.S
+++ b/kernel/arm64/ztrmm_kernel_4x4.S
@@ -49,7 +49,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#define pCRow3 x15 #define pA x16 #define alphaR x17 -#define alphaI x18 +#define alphaI x22 #define temp x19 #define tempOffset x20 #define tempK x21 diff --git a/kernel/generic/dot.c b/kernel/generic/dot.c index 5abbb735c..84568ee0b 100644 --- a/kernel/generic/dot.c +++ b/kernel/generic/dot.c @@ -47,7 +47,6 @@ FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y) if ( (inc_x == 1) && (inc_y == 1) ) { - int n1 = n & -4; #if V_SIMD && !defined(DSDOT) const int vstep = v_nlanes_f32; const int unrollx4 = n & (-vstep * 4); @@ -84,6 +83,7 @@ FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y) } dot = v_sum_f32(vsum0); #elif defined(DSDOT) + int n1 = n & -4; for (; i < n1; i += 4) { dot += (double) y[i] * (double) x[i] @@ -92,6 +92,7 @@ FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y) + (double) y[i+3] * (double) x[i+3] ; } #else + int n1 = n & -4; for (; i < n1; i += 4) { dot += y[i] * x[i] diff --git a/kernel/generic/gemm_small_matrix_kernel_nn.c b/kernel/generic/gemm_small_matrix_kernel_nn.c new file mode 100644 index 000000000..543e7e047 --- /dev/null +++ b/kernel/generic/gemm_small_matrix_kernel_nn.c @@ -0,0 +1,56 @@ +/*************************************************************************** +Copyright (c) 2020, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*****************************************************************************/ + +#include "common.h" + +#ifdef B0 +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, IFLOAT * A, BLASLONG lda, FLOAT alpha, IFLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) +#else +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, IFLOAT * A, BLASLONG lda, FLOAT alpha, IFLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) +#endif +{ + //naive implementation + //Column major + + BLASLONG i,j,k; + FLOAT result=0.0; +
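+ //Indexing (column major): A(i,k) = A[i + k*lda], B(k,j) = B[k + j*ldb], C(i,j) = C[i + j*ldc]. + //With B0 defined the kernel computes C = alpha*A*B; otherwise C = alpha*A*B + beta*C. + for(i=0; i<M; i++){ + for(j=0; j<N; j++){ + result=0.0; + for(k=0; k<K; k++){ + result += A[i + k * lda] * B[k + j * ldb]; + } +#ifdef B0 + C[i + j * ldc] = alpha * result; +#else + C[i + j * ldc] = beta * C[i + j * ldc] + alpha * result; +#endif + } + } + return 0; +} diff --git a/kernel/power/dasum.c b/kernel/power/dasum.c --- a/kernel/power/dasum.c +++ b/kernel/power/dasum.c - if ( n >= 16 ) + if ( n >= 32) { BLASLONG align = ((32 - ((uintptr_t)x & (uintptr_t)0x1F)) >> 3) & 0x3; for (i = 0; i < align; i++) { sumf += ABS(x[i]); } } - n1 = (n-i) & -16; + n1 = (n-i) & -32; if ( n1 > 0 ) { sumf += dasum_kernel_16(n1, &x[i]); diff --git a/kernel/power/dasum_microk_power10.c b/kernel/power/dasum_microk_power10.c index d1a21b4d1..110627fa4 100644 --- a/kernel/power/dasum_microk_power10.c +++ b/kernel/power/dasum_microk_power10.c @@ -34,6 +34,19 @@ static double dasum_kernel_16 (long n, double *x) __vector double t1; __vector double t2; __vector double t3; + __vector double t4; + __vector double t5; + __vector double t6; + __vector double t7; + __vector double a0; + __vector double a1; + __vector double a2; + __vector double a3; + __vector double a4; + __vector double a5; + __vector double a6; + __vector double a7; + __asm__ ( @@ -48,14 +61,27 @@ static double dasum_kernel_16 (long n, double *x) "xxlxor 38, 38, 38 \n\t" "xxlxor 39, 39, 39 \n\t" + "xxlxor %x11, %x11, %x11 \n\t" + "xxlxor %x12, %x12, %x12 \n\t" + "xxlxor %x13, %x13, %x13 \n\t" + "xxlxor %x14, %x14, %x14 \n\t" + "xxlxor %x15, %x15, %x15 \n\t" + "xxlxor %x16, %x16, %x16 \n\t" + "xxlxor %x17, %x17, %x17 \n\t" + "xxlxor %x18, %x18, %x18 \n\t" + "lxvp 40, 0(%2) \n\t" "lxvp 42, 32(%2) \n\t" "lxvp 44, 64(%2) \n\t" "lxvp 46, 96(%2) \n\t" + "lxvp 52, 128(%2) \n\t" + "lxvp 54, 160(%2) \n\t" + "lxvp 56, 192(%2) \n\t" + "lxvp 58, 224(%2) \n\t" - "addi %2, %2, 128 \n\t" + "addi %2, %2, 256 \n\t" - "addic. %1, %1, -16 \n\t" + "addic. %1, %1, -32 \n\t" "ble two%= \n\t" ".align 5 \n" @@ -65,33 +91,52 @@ static double dasum_kernel_16 (long n, double *x) "xvabsdp 49, 41 \n\t" "xvabsdp 50, 42 \n\t" "xvabsdp 51, 43 \n\t" - "lxvp 40, 0(%2) \n\t" - "xvabsdp %x3, 44 \n\t" "xvabsdp %x4, 45 \n\t" - "lxvp 42, 32(%2) \n\t" - - "xvabsdp %x5, 46 \n\t" "xvabsdp %x6, 47 \n\t" - "lxvp 44, 64(%2) \n\t" - "xvadddp 32, 32, 48 \n\t" "xvadddp 33, 33, 49 \n\t" - - "lxvp 46, 96(%2) \n\t" - "xvadddp 34, 34, 50 \n\t" "xvadddp 35, 35, 51 \n\t" - "addi %2, %2, 128 \n\t" + "lxvp 40, 0(%2) \n\t" + "lxvp 42, 32(%2) \n\t" + "lxvp 44, 64(%2) \n\t" + "lxvp 46, 96(%2) \n\t" + "xvadddp 36, 36, %x3 \n\t" "xvadddp 37, 37, %x4 \n\t" - "addic. %1, %1, -16 \n\t" "xvadddp 38, 38, %x5 \n\t" "xvadddp 39, 39, %x6 \n\t" + "xvabsdp 60, 52 \n\t" + "xvabsdp 61, 53 \n\t" + "xvabsdp 62, 54 \n\t" + "xvabsdp 63, 55 \n\t" + + "xvabsdp %x7, 56 \n\t" + "xvabsdp %x8, 57 \n\t" + "xvabsdp %x9, 58 \n\t" + "xvabsdp %x10, 59 \n\t" + + "xvadddp %x11, %x11, 60 \n\t" + "xvadddp %x12, %x12, 61 \n\t" + "xvadddp %x13, %x13, 62 \n\t" + "xvadddp %x14, %x14, 63 \n\t" + + "lxvp 52, 128(%2) \n\t" + "lxvp 54, 160(%2) \n\t" + "lxvp 56, 192(%2) \n\t" + "lxvp 58, 224(%2) \n\t" + "xvadddp %x15, %x15, %x7 \n\t" + "xvadddp %x16, %x16, %x8 \n\t" + "xvadddp %x17, %x17, %x9 \n\t" + "xvadddp %x18, %x18, %x10 \n\t" + "addi %2, %2, 256 \n\t" + "addic.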
%1, %1, -32 \n\t" + "bgt one%= \n" "two%=: \n\t" @@ -114,6 +159,25 @@ static double dasum_kernel_16 (long n, double *x) "xvadddp 38, 38, %x5 \n\t" "xvadddp 39, 39, %x6 \n\t" + "xvabsdp 60, 52 \n\t" + "xvabsdp 61, 53 \n\t" + "xvabsdp 62, 54 \n\t" + "xvabsdp 63, 55 \n\t" + + "xvabsdp %x7, 56 \n\t" + "xvabsdp %x8, 57 \n\t" + "xvabsdp %x9, 58 \n\t" + "xvabsdp %x10, 59 \n\t" + "xvadddp %x11, %x11, 60 \n\t" + "xvadddp %x12, %x12, 61 \n\t" + "xvadddp %x13, %x13, 62 \n\t" + "xvadddp %x14, %x14, 63 \n\t" + + "xvadddp %x15, %x15, %x7 \n\t" + "xvadddp %x16, %x16, %x8 \n\t" + "xvadddp %x17, %x17, %x9 \n\t" + "xvadddp %x18, %x18, %x10 \n\t" + "xvadddp 32, 32, 33 \n\t" "xvadddp 34, 34, 35 \n\t" "xvadddp 36, 36, 37 \n\t" @@ -122,7 +186,18 @@ static double dasum_kernel_16 (long n, double *x) "xvadddp 32, 32, 34 \n\t" "xvadddp 36, 36, 38 \n\t" + "xvadddp %x11, %x11, %x12 \n\t" + "xvadddp %x13, %x13, %x14 \n\t" + "xvadddp %x15, %x15, %x16 \n\t" + "xvadddp %x17, %x17, %x18 \n\t" + + "xvadddp %x11, %x11, %x13 \n\t" + "xvadddp %x15, %x15, %x17 \n\t" + + "xvadddp %x11, %x11, %x15 \n\t" + "xvadddp 32, 32, 36 \n\t" + "xvadddp 32, 32, %x11 \n\t" XXSWAPD_S(33,32) "xsadddp %x0, 32, 33 \n" @@ -136,14 +211,27 @@ static double dasum_kernel_16 (long n, double *x) "=wa" (t0), // 3 "=wa" (t1), // 4 "=wa" (t2), // 5 - "=wa" (t3) // 6 + "=wa" (t3), // 6 + "=wa" (t4), // 7 + "=wa" (t5), // 8 + "=wa" (t6), // 9 + "=wa" (t7), // 10 + "=wa" (a0), // 11 + "=wa" (a1), // 12 + "=wa" (a2), // 13 + "=wa" (a3), // 14 + "=wa" (a4), // 15 + "=wa" (a5), // 16 + "=wa" (a6), // 17 + "=wa" (a7) // 18 : "m" (*x) : "cr0", "vs32","vs33","vs34","vs35","vs36","vs37","vs38","vs39", "vs40","vs41","vs42","vs43","vs44","vs45","vs46","vs47", - "vs48","vs49","vs50","vs51" + "vs48","vs49","vs50","vs51","vs52","vs53","vs54","vs55", + "vs56","vs57","vs58","vs59","vs60","vs61","vs62","vs63" ); return sum; diff --git a/kernel/power/drot.c b/kernel/power/drot.c index 3229878e4..30c7411cc 100644 --- a/kernel/power/drot.c +++ b/kernel/power/drot.c @@ -110,8 +110,6 @@ int CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y, FLOAT { BLASLONG i=0; BLASLONG ix=0,iy=0; - FLOAT *x1=x; - FLOAT *y1=y; FLOAT temp; if ( n <= 0 ) return(0); @@ -139,7 +137,7 @@ int CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y, FLOAT BLASLONG n1 = n & -16; if ( n1 > 0 ) { - drot_kernel_16(n1, x1, y1, c, s); + drot_kernel_16(n1, x, y, c, s); i=n1; } #endif diff --git a/kernel/power/idamax.c b/kernel/power/idamax.c index 5016f67dd..f1ef00066 100644 --- a/kernel/power/idamax.c +++ b/kernel/power/idamax.c @@ -330,10 +330,10 @@ BLASLONG CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x) { if (inc_x == 1) { - BLASLONG n1 = n & -32; #if defined(_CALL_ELF) && (_CALL_ELF == 2) #if defined(__VEC__) || defined(__ALTIVEC__) + BLASLONG n1 = n & -32; if (n1 > 0) { max = diamax_kernel_32(n1, x, &maxf); diff --git a/kernel/power/trsm_kernel_LN_power10.c b/kernel/power/trsm_kernel_LN_power10.c index 5ca1603a6..246c3a236 100644 --- a/kernel/power/trsm_kernel_LN_power10.c +++ b/kernel/power/trsm_kernel_LN_power10.c @@ -389,7 +389,6 @@ static inline __attribute__ ((always_inline)) void solve16x8(FLOAT *a, FLOAT *b, vector FLOAT *Vc6 = (vector FLOAT *) c6; vector FLOAT *Vc7 = (vector FLOAT *) c7; vector FLOAT VbS0, VbS1, VbS2, VbS3, VbS4, VbS5, VbS6, VbS7; - int j; b[120] = (c0[15] *= a[255]); b[121] = (c1[15] *= a[255]); diff --git a/kernel/power/trsm_kernel_LT_power10.c b/kernel/power/trsm_kernel_LT_power10.c index 14ff12fe4..51f3a4e61 100644 --- 
a/kernel/power/trsm_kernel_LT_power10.c +++ b/kernel/power/trsm_kernel_LT_power10.c @@ -390,7 +390,6 @@ static inline __attribute__ ((always_inline)) void solve16x8(FLOAT *a, FLOAT *b, vector FLOAT *Vc6 = (vector FLOAT *) c6; vector FLOAT *Vc7 = (vector FLOAT *) c7; vector FLOAT VbS0, VbS1, VbS2, VbS3, VbS4, VbS5, VbS6, VbS7; - int j; b[0] = (c0[0] *= a[0]); b[1] = (c1[0] *= a[0]); diff --git a/kernel/power/zgemv_n_4.c b/kernel/power/zgemv_n_4.c index 1f7199c89..366c21681 100644 --- a/kernel/power/zgemv_n_4.c +++ b/kernel/power/zgemv_n_4.c @@ -607,7 +607,6 @@ static void add_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_dest, FLOAT int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha_r, FLOAT alpha_i, FLOAT *a, BLASLONG lda, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y, FLOAT * buffer) { BLASLONG i; - BLASLONG j; FLOAT *a_ptr; FLOAT *x_ptr; FLOAT *y_ptr; diff --git a/kernel/power/zgemv_n_power10.c b/kernel/power/zgemv_n_power10.c index f5bb8d70e..a545b00d8 100644 --- a/kernel/power/zgemv_n_power10.c +++ b/kernel/power/zgemv_n_power10.c @@ -738,7 +738,6 @@ static void add_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_dest, FLOAT int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha_r, FLOAT alpha_i, FLOAT *a, BLASLONG lda, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y, FLOAT * buffer) { BLASLONG i; - BLASLONG j; FLOAT *a_ptr; FLOAT *x_ptr; FLOAT *y_ptr; diff --git a/kernel/setparam-ref.c b/kernel/setparam-ref.c index 1e846a61c..19b7b5f0b 100644 --- a/kernel/setparam-ref.c +++ b/kernel/setparam-ref.c @@ -112,6 +112,11 @@ gotoblas_t TABLE_NAME = { #else NULL,NULL, #endif +#ifdef SMALL_MATRIX_OPT + sbgemm_small_matrix_permitTS, + sbgemm_small_kernel_nnTS, sbgemm_small_kernel_ntTS, sbgemm_small_kernel_tnTS, sbgemm_small_kernel_ttTS, + sbgemm_small_kernel_b0_nnTS, sbgemm_small_kernel_b0_ntTS, sbgemm_small_kernel_b0_tnTS, sbgemm_small_kernel_b0_ttTS, +#endif #endif #if ( BUILD_SINGLE==1) || (BUILD_DOUBLE==1) || (BUILD_COMPLEX==1) || (BUILD_COMPLEX16==1) @@ -171,6 +176,14 @@ gotoblas_t TABLE_NAME = { sgemm_oncopyTS, sgemm_otcopyTS, #endif +#if BUILD_SINGLE == 1 +#ifdef SMALL_MATRIX_OPT + sgemm_small_matrix_permitTS, + sgemm_small_kernel_nnTS, sgemm_small_kernel_ntTS, sgemm_small_kernel_tnTS, sgemm_small_kernel_ttTS, + sgemm_small_kernel_b0_nnTS, sgemm_small_kernel_b0_ntTS, sgemm_small_kernel_b0_tnTS, sgemm_small_kernel_b0_ttTS, +#endif +#endif + #if (BUILD_SINGLE==1) || (BUILD_DOUBLE==1) strsm_kernel_LNTS, strsm_kernel_LTTS, strsm_kernel_RNTS, strsm_kernel_RTTS, #if SGEMM_DEFAULT_UNROLL_M != SGEMM_DEFAULT_UNROLL_N @@ -257,6 +270,11 @@ gotoblas_t TABLE_NAME = { #endif #if (BUILD_DOUBLE==1) +#ifdef SMALL_MATRIX_OPT + dgemm_small_matrix_permitTS, + dgemm_small_kernel_nnTS, dgemm_small_kernel_ntTS, dgemm_small_kernel_tnTS, dgemm_small_kernel_ttTS, + dgemm_small_kernel_b0_nnTS, dgemm_small_kernel_b0_ntTS, dgemm_small_kernel_b0_tnTS, dgemm_small_kernel_b0_ttTS, +#endif dtrsm_kernel_LNTS, dtrsm_kernel_LTTS, dtrsm_kernel_RNTS, dtrsm_kernel_RTTS, #if DGEMM_DEFAULT_UNROLL_M != DGEMM_DEFAULT_UNROLL_N dtrsm_iunucopyTS, dtrsm_iunncopyTS, dtrsm_iutucopyTS, dtrsm_iutncopyTS, @@ -389,6 +407,18 @@ gotoblas_t TABLE_NAME = { #endif cgemm_oncopyTS, cgemm_otcopyTS, +#ifdef SMALL_MATRIX_OPT + cgemm_small_matrix_permitTS, + cgemm_small_kernel_nnTS, cgemm_small_kernel_ntTS, cgemm_small_kernel_nrTS, cgemm_small_kernel_ncTS, + cgemm_small_kernel_tnTS, cgemm_small_kernel_ttTS, cgemm_small_kernel_trTS, cgemm_small_kernel_tcTS, + cgemm_small_kernel_rnTS, cgemm_small_kernel_rtTS, 
cgemm_small_kernel_rrTS, cgemm_small_kernel_rcTS, + cgemm_small_kernel_cnTS, cgemm_small_kernel_ctTS, cgemm_small_kernel_crTS, cgemm_small_kernel_ccTS, + cgemm_small_kernel_b0_nnTS, cgemm_small_kernel_b0_ntTS, cgemm_small_kernel_b0_nrTS, cgemm_small_kernel_b0_ncTS, + cgemm_small_kernel_b0_tnTS, cgemm_small_kernel_b0_ttTS, cgemm_small_kernel_b0_trTS, cgemm_small_kernel_b0_tcTS, + cgemm_small_kernel_b0_rnTS, cgemm_small_kernel_b0_rtTS, cgemm_small_kernel_b0_rrTS, cgemm_small_kernel_b0_rcTS, + cgemm_small_kernel_b0_cnTS, cgemm_small_kernel_b0_ctTS, cgemm_small_kernel_b0_crTS, cgemm_small_kernel_b0_ccTS, +#endif + ctrsm_kernel_LNTS, ctrsm_kernel_LTTS, ctrsm_kernel_LRTS, ctrsm_kernel_LCTS, ctrsm_kernel_RNTS, ctrsm_kernel_RTTS, ctrsm_kernel_RRTS, ctrsm_kernel_RCTS, @@ -533,6 +563,18 @@ gotoblas_t TABLE_NAME = { #endif zgemm_oncopyTS, zgemm_otcopyTS, +#ifdef SMALL_MATRIX_OPT + zgemm_small_matrix_permitTS, + zgemm_small_kernel_nnTS, zgemm_small_kernel_ntTS, zgemm_small_kernel_nrTS, zgemm_small_kernel_ncTS, + zgemm_small_kernel_tnTS, zgemm_small_kernel_ttTS, zgemm_small_kernel_trTS, zgemm_small_kernel_tcTS, + zgemm_small_kernel_rnTS, zgemm_small_kernel_rtTS, zgemm_small_kernel_rrTS, zgemm_small_kernel_rcTS, + zgemm_small_kernel_cnTS, zgemm_small_kernel_ctTS, zgemm_small_kernel_crTS, zgemm_small_kernel_ccTS, + zgemm_small_kernel_b0_nnTS, zgemm_small_kernel_b0_ntTS, zgemm_small_kernel_b0_nrTS, zgemm_small_kernel_b0_ncTS, + zgemm_small_kernel_b0_tnTS, zgemm_small_kernel_b0_ttTS, zgemm_small_kernel_b0_trTS, zgemm_small_kernel_b0_tcTS, + zgemm_small_kernel_b0_rnTS, zgemm_small_kernel_b0_rtTS, zgemm_small_kernel_b0_rrTS, zgemm_small_kernel_b0_rcTS, + zgemm_small_kernel_b0_cnTS, zgemm_small_kernel_b0_ctTS, zgemm_small_kernel_b0_crTS, zgemm_small_kernel_b0_ccTS, +#endif + ztrsm_kernel_LNTS, ztrsm_kernel_LTTS, ztrsm_kernel_LRTS, ztrsm_kernel_LCTS, ztrsm_kernel_RNTS, ztrsm_kernel_RTTS, ztrsm_kernel_RRTS, ztrsm_kernel_RCTS, diff --git a/kernel/x86_64/KERNEL.COOPERLAKE b/kernel/x86_64/KERNEL.COOPERLAKE index 0b2f3c0ed..dba94aea8 100644 --- a/kernel/x86_64/KERNEL.COOPERLAKE +++ b/kernel/x86_64/KERNEL.COOPERLAKE @@ -1 +1,22 @@ include $(KERNELDIR)/KERNEL.SKYLAKEX + +SBGEMM_SMALL_M_PERMIT = sbgemm_small_kernel_permit_cooperlake.c +SBGEMM_SMALL_K_NN = sbgemm_small_kernel_nn_cooperlake.c +SBGEMM_SMALL_K_B0_NN = sbgemm_small_kernel_nn_cooperlake.c +SBGEMM_SMALL_K_NT = sbgemm_small_kernel_nt_cooperlake.c +SBGEMM_SMALL_K_B0_NT = sbgemm_small_kernel_nt_cooperlake.c +SBGEMM_SMALL_K_TN = sbgemm_small_kernel_tn_cooperlake.c +SBGEMM_SMALL_K_B0_TN = sbgemm_small_kernel_tn_cooperlake.c +SBGEMM_SMALL_K_TT = sbgemm_small_kernel_tt_cooperlake.c +SBGEMM_SMALL_K_B0_TT = sbgemm_small_kernel_tt_cooperlake.c + +SBGEMM_BETA = sgemm_beta_skylakex.c +SBGEMMKERNEL = sbgemm_kernel_16x4_cooperlake.c +SBGEMMINCOPY = sbgemm_ncopy_16_cooperlake.c +SBGEMMITCOPY = sbgemm_tcopy_16_cooperlake.c +SBGEMMONCOPY = sbgemm_ncopy_4_cooperlake.c +SBGEMMOTCOPY = sbgemm_tcopy_4_cooperlake.c +SBGEMMINCOPYOBJ = sbgemm_incopy$(TSUFFIX).$(SUFFIX) +SBGEMMITCOPYOBJ = sbgemm_itcopy$(TSUFFIX).$(SUFFIX) +SBGEMMONCOPYOBJ = sbgemm_oncopy$(TSUFFIX).$(SUFFIX) +SBGEMMOTCOPYOBJ = sbgemm_otcopy$(TSUFFIX).$(SUFFIX) diff --git a/kernel/x86_64/KERNEL.SKYLAKEX b/kernel/x86_64/KERNEL.SKYLAKEX index 3d71584fe..6b4961bc2 100644 --- a/kernel/x86_64/KERNEL.SKYLAKEX +++ b/kernel/x86_64/KERNEL.SKYLAKEX @@ -10,6 +10,15 @@ STRSMKERNEL_LN = ../generic/trsm_kernel_LN.c STRSMKERNEL_LT = ../generic/trsm_kernel_LT.c STRSMKERNEL_RN = ../generic/trsm_kernel_RN.c STRSMKERNEL_RT = 
../generic/trsm_kernel_RT.c +SGEMM_SMALL_M_PERMIT = sgemm_small_kernel_permit_skylakex.c +SGEMM_SMALL_K_NN = sgemm_small_kernel_nn_skylakex.c +SGEMM_SMALL_K_B0_NN = sgemm_small_kernel_nn_skylakex.c +SGEMM_SMALL_K_NT = sgemm_small_kernel_nt_skylakex.c +SGEMM_SMALL_K_B0_NT = sgemm_small_kernel_nt_skylakex.c +SGEMM_SMALL_K_TN = sgemm_small_kernel_tn_skylakex.c +SGEMM_SMALL_K_B0_TN = sgemm_small_kernel_tn_skylakex.c +SGEMM_SMALL_K_TT = sgemm_small_kernel_tt_skylakex.c +SGEMM_SMALL_K_B0_TT = sgemm_small_kernel_tt_skylakex.c DGEMMKERNEL = dgemm_kernel_16x2_skylakex.c DTRMMKERNEL = dgemm_kernel_16x2_skylakex.c @@ -18,6 +27,15 @@ DGEMMITCOPY = dgemm_tcopy_16_skylakex.c DGEMMONCOPY = ../generic/gemm_ncopy_2.c DGEMMOTCOPY = ../generic/gemm_tcopy_2.c DTRSMKERNEL_RN = ../generic/trsm_kernel_RN.c +DGEMM_SMALL_M_PERMIT = dgemm_small_kernel_permit_skylakex.c +DGEMM_SMALL_K_NN = dgemm_small_kernel_nn_skylakex.c +DGEMM_SMALL_K_B0_NN = dgemm_small_kernel_nn_skylakex.c +DGEMM_SMALL_K_NT = dgemm_small_kernel_nt_skylakex.c +DGEMM_SMALL_K_B0_NT = dgemm_small_kernel_nt_skylakex.c +DGEMM_SMALL_K_TN = dgemm_small_kernel_tn_skylakex.c +DGEMM_SMALL_K_B0_TN = dgemm_small_kernel_tn_skylakex.c +DGEMM_SMALL_K_TT = dgemm_small_kernel_tt_skylakex.c +DGEMM_SMALL_K_B0_TT = dgemm_small_kernel_tt_skylakex.c SGEMM_BETA = sgemm_beta_skylakex.c DGEMM_BETA = dgemm_beta_skylakex.c diff --git a/kernel/x86_64/bf16_common_macros.h b/kernel/x86_64/bf16_common_macros.h index 1014ecc4d..cdb4beff6 100644 --- a/kernel/x86_64/bf16_common_macros.h +++ b/kernel/x86_64/bf16_common_macros.h @@ -29,6 +29,16 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include <immintrin.h>
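+/* Inline-asm helpers: _MM512_BROADCASTD_EPI32 broadcasts one 32-bit element + (a pair of bf16 values) from memory into every lane of a zmm register, + and PREFETCH_T0 issues a prefetcht0 hint for the given address. */ +#define _MM512_BROADCASTD_EPI32(addr, zmm) \ + __asm__ ("vpbroadcastd (%1), %0;" \ + : "=v" (zmm) \ + : "r" (addr) ) + +#define PREFETCH_T0(addr) \ + __asm__ ("prefetcht0 (%0);" \ + : \ + : "r" (addr) ) + #define EXTRACT_LOW_256_FROM_512_2X(reg256, reg512) \ reg256##_0 = _mm512_castps512_ps256(reg512##_0); \ reg256##_1 = _mm512_castps512_ps256(reg512##_1); @@ -46,25 +56,25 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 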
#define BF16_MATRIX_LOAD_8x16(regArray, a, lda, idx_m, idx_n) \ - regArray##_0 = _mm256_loadu_si256(&a[(idx_m+0)*lda + idx_n]); \ - regArray##_1 = _mm256_loadu_si256(&a[(idx_m+1)*lda + idx_n]); \ - regArray##_2 = _mm256_loadu_si256(&a[(idx_m+2)*lda + idx_n]); \ - regArray##_3 = _mm256_loadu_si256(&a[(idx_m+3)*lda + idx_n]); \ - regArray##_4 = _mm256_loadu_si256(&a[(idx_m+4)*lda + idx_n]); \ - regArray##_5 = _mm256_loadu_si256(&a[(idx_m+5)*lda + idx_n]); \ - regArray##_6 = _mm256_loadu_si256(&a[(idx_m+6)*lda + idx_n]); \ - regArray##_7 = _mm256_loadu_si256(&a[(idx_m+7)*lda + idx_n]); + regArray##_0 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+0)*lda + idx_n])); \ + regArray##_1 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+1)*lda + idx_n])); \ + regArray##_2 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+2)*lda + idx_n])); \ + regArray##_3 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+3)*lda + idx_n])); \ + regArray##_4 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+4)*lda + idx_n])); \ + regArray##_5 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+5)*lda + idx_n])); \ + regArray##_6 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+6)*lda + idx_n])); \ + regArray##_7 = _mm256_loadu_si256((__m256i *)(&a[(idx_m+7)*lda + idx_n])); #define BF16_MATRIX_LOAD_8x8(regArray, a, lda, idx_m, idx_n) \ - regArray##_0 = _mm_loadu_si128(&a[(idx_m+0)*lda + idx_n]); \ - regArray##_1 = _mm_loadu_si128(&a[(idx_m+1)*lda + idx_n]); \ - regArray##_2 = _mm_loadu_si128(&a[(idx_m+2)*lda + idx_n]); \ - regArray##_3 = _mm_loadu_si128(&a[(idx_m+3)*lda + idx_n]); \ - regArray##_4 = _mm_loadu_si128(&a[(idx_m+4)*lda + idx_n]); \ - regArray##_5 = _mm_loadu_si128(&a[(idx_m+5)*lda + idx_n]); \ - regArray##_6 = _mm_loadu_si128(&a[(idx_m+6)*lda + idx_n]); \ - regArray##_7 = _mm_loadu_si128(&a[(idx_m+7)*lda + idx_n]); + regArray##_0 = _mm_loadu_si128((__m128i *)(&a[(idx_m+0)*lda + idx_n])); \ + regArray##_1 = _mm_loadu_si128((__m128i *)(&a[(idx_m+1)*lda + idx_n])); \ + regArray##_2 = _mm_loadu_si128((__m128i *)(&a[(idx_m+2)*lda + idx_n])); \ + regArray##_3 = _mm_loadu_si128((__m128i *)(&a[(idx_m+3)*lda + idx_n])); \ + regArray##_4 = _mm_loadu_si128((__m128i *)(&a[(idx_m+4)*lda + idx_n])); \ + regArray##_5 = _mm_loadu_si128((__m128i *)(&a[(idx_m+5)*lda + idx_n])); \ + regArray##_6 = _mm_loadu_si128((__m128i *)(&a[(idx_m+6)*lda + idx_n])); \ + regArray##_7 = _mm_loadu_si128((__m128i *)(&a[(idx_m+7)*lda + idx_n])); #define BF16_MATRIX_LOAD_1x32(regArray, a, lda, idx_m, idx_n) \ @@ -143,11 +153,11 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define BF16_VECTOR_LOAD_1x16(reg, x, idx_n) \ - reg = _mm256_loadu_si256(x + idx_n); + reg = _mm256_loadu_si256((__m256i *)(x + idx_n)); #define BF16_VECTOR_LOAD_1x8(reg, x, idx_n) \ - reg = _mm_loadu_si128(x + idx_n); + reg = _mm_loadu_si128((__m128i *)(x + idx_n)); #define BF16_VECTOR_MASKZ_LOAD_1x32(reg, x, idx_n, mask) \ @@ -721,6 +731,48 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
_mm_mask_storeu_ps(targetAddr, mask, regResult); +/* Store 16 (result + y) to y +*/ +#define STORE16_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr) \ + regResult = _mm512_add_ps(regResult, _mm512_loadu_ps(targetAddr)); \ + _mm512_storeu_ps(targetAddr, regResult); + + +/* Masked store 16 (result + y) to y +*/ +#define STORE16_MASK_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr, mask) \ + regResult = _mm512_add_ps(regResult, _mm512_maskz_loadu_ps(mask, targetAddr)); \ + _mm512_mask_storeu_ps(targetAddr, mask, regResult); + + +/* Store 8 (result + y) to y +*/ +#define STORE8_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr) \ + regResult = _mm256_add_ps(regResult, _mm256_loadu_ps(targetAddr)); \ + _mm256_storeu_ps(targetAddr, regResult); + + +/* Masked store 8 (result + y) to y +*/ +#define STORE8_MASK_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr, mask) \ + regResult = _mm256_add_ps(regResult, _mm256_maskz_loadu_ps(mask, targetAddr)); \ + _mm256_mask_storeu_ps(targetAddr, mask, regResult); + + +/* Store 4 (result + y) to y +*/ +#define STORE4_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr) \ + regResult = _mm_add_ps(regResult, _mm_loadu_ps(targetAddr)); \ + _mm_storeu_ps(targetAddr, regResult); + + +/* Masked store 4 (result + y) to y +*/ +#define STORE4_MASK_COMPLETE_RESULT_ONE_ONE(regResult, targetAddr, mask) \ + regResult = _mm_add_ps(regResult, _mm_maskz_loadu_ps(mask, targetAddr)); \ + _mm_mask_storeu_ps(targetAddr, mask, regResult); + + /* Store 16 (alpha * result) to y */ #define STORE16_COMPLETE_RESULT_ALPHA(regResult, targetAddr) \ diff --git a/kernel/x86_64/casum_microk_skylakex-2.c b/kernel/x86_64/casum_microk_skylakex-2.c index d51929f9f..b398aa6e1 100644 --- a/kernel/x86_64/casum_microk_skylakex-2.c +++ b/kernel/x86_64/casum_microk_skylakex-2.c @@ -15,7 +15,7 @@ static FLOAT casum_kernel(BLASLONG n, FLOAT *x) if (n2 < 64) { __m128 accum_10, accum_11, accum_12, accum_13; - __m128 abs_mask1; + __m128 abs_mask1 = abs_mask1; accum_10 = _mm_setzero_ps(); accum_11 = _mm_setzero_ps(); diff --git a/kernel/x86_64/dasum_microk_haswell-2.c b/kernel/x86_64/dasum_microk_haswell-2.c index 4fc73ddd4..fd9da7ebe 100644 --- a/kernel/x86_64/dasum_microk_haswell-2.c +++ b/kernel/x86_64/dasum_microk_haswell-2.c @@ -38,10 +38,10 @@ static FLOAT dasum_kernel(BLASLONG n, FLOAT *x1) __m256i abs_mask = _mm256_set1_epi64x(0x7fffffffffffffff); for (i = 0; i < tail_index_AVX2; i += 16) { - accum_0 += (__m256d)_mm256_and_si256(_mm256_load_si256(&x1[i+ 0]), abs_mask); - accum_1 += (__m256d)_mm256_and_si256(_mm256_load_si256(&x1[i+ 4]), abs_mask); - accum_2 += (__m256d)_mm256_and_si256(_mm256_load_si256(&x1[i+ 8]), abs_mask); - accum_3 += (__m256d)_mm256_and_si256(_mm256_load_si256(&x1[i+12]), abs_mask); + accum_0 += (__m256d)_mm256_and_si256(_mm256_load_si256((__m256i*)&x1[i+ 0]), abs_mask); + accum_1 += (__m256d)_mm256_and_si256(_mm256_load_si256((__m256i*)&x1[i+ 4]), abs_mask); + accum_2 += (__m256d)_mm256_and_si256(_mm256_load_si256((__m256i*)&x1[i+ 8]), abs_mask); + accum_3 += (__m256d)_mm256_and_si256(_mm256_load_si256((__m256i*)&x1[i+12]), abs_mask); } accum_0 = accum_0 + accum_1 + accum_2 + accum_3; @@ -63,10 +63,10 @@ static FLOAT dasum_kernel(BLASLONG n, FLOAT *x1) __m128i abs_mask2 = _mm_set1_epi64x(0x7fffffffffffffff); for (i = tail_index_AVX2; i < tail_index_SSE; i += 8) { - accum_20 += (__m128d)_mm_and_si128(_mm_loadu_si128(&x1[i + 0]), abs_mask2); - accum_21 += (__m128d)_mm_and_si128(_mm_loadu_si128(&x1[i + 2]), abs_mask2); - accum_22 += (__m128d)_mm_and_si128(_mm_loadu_si128(&x1[i + 4]), abs_mask2); 
- accum_23 += (__m128d)_mm_and_si128(_mm_loadu_si128(&x1[i + 6]), abs_mask2); + accum_20 += (__m128d)_mm_and_si128(_mm_loadu_si128((__m128i*)&x1[i + 0]), abs_mask2); + accum_21 += (__m128d)_mm_and_si128(_mm_loadu_si128((__m128i*)&x1[i + 2]), abs_mask2); + accum_22 += (__m128d)_mm_and_si128(_mm_loadu_si128((__m128i*)&x1[i + 4]), abs_mask2); + accum_23 += (__m128d)_mm_and_si128(_mm_loadu_si128((__m128i*)&x1[i + 6]), abs_mask2); } accum_20 = accum_20 + accum_21 + accum_22 + accum_23; diff --git a/kernel/x86_64/dasum_microk_skylakex-2.c b/kernel/x86_64/dasum_microk_skylakex-2.c index aea8c02d9..83bc078b3 100644 --- a/kernel/x86_64/dasum_microk_skylakex-2.c +++ b/kernel/x86_64/dasum_microk_skylakex-2.c @@ -58,10 +58,10 @@ static FLOAT dasum_kernel(BLASLONG n, FLOAT *x1) __m128i abs_mask2 = _mm_set1_epi64x(0x7fffffffffffffff); for (i = tail_index_AVX512; i < tail_index_SSE; i += 8) { - accum_20 += (__m128d)_mm_and_si128(_mm_loadu_si128(&x1[i + 0]), abs_mask2); - accum_21 += (__m128d)_mm_and_si128(_mm_loadu_si128(&x1[i + 2]), abs_mask2); - accum_22 += (__m128d)_mm_and_si128(_mm_loadu_si128(&x1[i + 4]), abs_mask2); - accum_23 += (__m128d)_mm_and_si128(_mm_loadu_si128(&x1[i + 6]), abs_mask2); + accum_20 += (__m128d)_mm_and_si128(_mm_loadu_si128((__m128i*)&x1[i + 0]), abs_mask2); + accum_21 += (__m128d)_mm_and_si128(_mm_loadu_si128((__m128i*)&x1[i + 2]), abs_mask2); + accum_22 += (__m128d)_mm_and_si128(_mm_loadu_si128((__m128i*)&x1[i + 4]), abs_mask2); + accum_23 += (__m128d)_mm_and_si128(_mm_loadu_si128((__m128i*)&x1[i + 6]), abs_mask2); } accum_20 = accum_20 + accum_21 + accum_22 + accum_23; diff --git a/kernel/x86_64/dgemm_small_kernel_nn_skylakex.c b/kernel/x86_64/dgemm_small_kernel_nn_skylakex.c new file mode 100644 index 000000000..d9b380fff --- /dev/null +++ b/kernel/x86_64/dgemm_small_kernel_nn_skylakex.c @@ -0,0 +1,590 @@ +/*************************************************************************** +Copyright (c) 2021, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*****************************************************************************/ + +#include <immintrin.h> +#include "common.h" +#include <stdio.h> +#include <memory.h> + +#define DECLARE_RESULT_512(M, N) __m512d result##M##N = _mm512_setzero_pd() +#define LOAD_A_512(M, N) __m512d Aval##M = _mm512_loadu_pd(&A[lda * k + i + (M*8)]) +#define MASK_LOAD_A_512(M, N) __m512d Aval##M = _mm512_maskz_loadu_pd(mask, &A[lda * k + i + (M*8)]) +#define BROADCAST_LOAD_B_512(M, N) __m512d Bval##N = _mm512_broadcastsd_pd(_mm_load_pd1(&B[k + ldb * (j+N)])) +#define MATMUL_512(M, N) result##M##N = _mm512_fmadd_pd(Aval##M, Bval##N, result##M##N) +#if defined(B0) +#define STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ + _mm512_storeu_pd(&C[(j+N)*ldc + i + (M*8)], result##M##N) +#define MASK_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ + _mm512_mask_storeu_pd(&C[(j+N)*ldc + i + (M*8)], mask, result##M##N) +#else +#define STORE_512(M, N) \ + result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ + asm("vfmadd231pd (%1), %2, %0": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*8)]), "v"(beta_512)); \ + _mm512_storeu_pd(&C[(j+N)*ldc + i + (M*8)], result##M##N) +#define MASK_STORE_512(M, N) \ + result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ + asm("vfmadd231pd (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*8)]), "v"(beta_512), "k"(mask)); \ + _mm512_mask_storeu_pd(&C[(j+N)*ldc + i + (M*8)], mask, result##M##N) +#endif + +#define LOAD_KA_512(M, N) __m512d Aval##M = _mm512_loadu_pd(&mbuf[(mi + M)*K + k]); +#define LOAD_KB_512(M, N) __m512d Bval##N = _mm512_loadu_pd(&B[(j + N)*ldb + k]) +#define MASK_LOAD_KA_512(M, N) __m512d Aval##M = _mm512_maskz_loadu_pd(mask, &mbuf[(mi + M)*K + k]) +#define MASK_LOAD_KB_512(M, N) __m512d Bval##N = _mm512_maskz_loadu_pd(mask, &B[(j + N)*ldb + k])
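+/* REDUCE_4: transpose the four 8-wide accumulators rr0..rr3 with unpack and + permutex2var, sum the transposed parts, and leave the four horizontal sums, + already scaled by alpha, in the 256-bit vector s0 (one element per accumulator). */ +#define REDUCE_4(rr0, rr1, rr2, rr3) \ + __m512d r0, r1, r2, r3, t0, t1, t2, t3;\ + r0 = _mm512_unpacklo_pd(rr0, rr1); r1 = _mm512_unpackhi_pd(rr0, rr1); \ + r2 = _mm512_unpacklo_pd(rr2, rr3); r3 = _mm512_unpackhi_pd(rr2, rr3); \ + t0 = _mm512_permutex2var_pd(r0, idx_lo, r2); t1 = _mm512_permutex2var_pd(r1, idx_lo, r3); \ + t2 = _mm512_permutex2var_pd(r0, idx_hi, r2); t3 = _mm512_permutex2var_pd(r1, idx_hi, r3); \ + r0 = _mm512_add_pd(t0, t1); r1 = _mm512_add_pd(t2, t3); t0 = _mm512_add_pd(r0, r1); \ + __m256d s0, s1; \ + s0 = _mm512_extractf64x4_pd(t0, 0); s1 = _mm512_extractf64x4_pd(t0, 1); \ + s0 = _mm256_add_pd(s0, s1); s0 = _mm256_mul_pd(alpha_256, s0); +#define REDUCE_M4(N) REDUCE_4(result0##N, result1##N, result2##N, result3##N) +#define REDUCE_N4(M) REDUCE_4(result##M##0, result##M##1, result##M##2, result##M##3) +#if defined(B0) +#define STORE_REDUCE(M, N) C[(j+N)*ldc + i + M] = alpha * _mm512_reduce_add_pd(result##M##N); +#define STORE_REDUCE_M4(N) {\ + REDUCE_M4(N) \ + _mm256_storeu_pd(&C[(j + N)*ldc + i], s0); \ +} +#define STORE_REDUCE_N4(M) {\ + REDUCE_N4(M) \ + _mm256_i64scatter_pd(&C[j*ldc + i + M], vindex_n, s0, 8); \ +} +#else +#define STORE_REDUCE(M, N) C[(j+N)*ldc + i + M] = alpha * _mm512_reduce_add_pd(result##M##N) + beta * C[(j+N)*ldc + i + M]; +#define STORE_REDUCE_M4(N) {\ + REDUCE_M4(N) \ + asm("vfmadd231pd (%1), %2, %0": "+v"(s0):"r"(&C[(j + N)*ldc + i]), "v"(beta_256)); \ + _mm256_storeu_pd(&C[(j + N)*ldc + i], s0); \ +} +#define STORE_REDUCE_N4(M) {\ + REDUCE_N4(M) \ + s1 = _mm256_i64gather_pd(&C[j*ldc + i + M], vindex_n, 8); \ + s0 = _mm256_fmadd_pd(s1, beta_256, s0); \ + _mm256_i64scatter_pd(&C[j*ldc + i + M], vindex_n, s0, 8); \ +} +#endif + +#if defined(B0) 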
+int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) +#else +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) +#endif +{ + // column major + BLASLONG i, j, k; + + BLASLONG m32 = M & ~31; + BLASLONG m16 = M & ~15; + BLASLONG m8 = M & ~7; + BLASLONG m4 = M & ~3; + BLASLONG m2 = M & ~1; + + BLASLONG n6 = N - (N % 6); + BLASLONG n4 = N & ~3; + BLASLONG n2 = N & ~1; + + + __m512d alpha_512 = _mm512_broadcastsd_pd(_mm_load_pd1(&alpha)); +#if !defined(B0) + __m512d beta_512 = _mm512_broadcastsd_pd(_mm_load_pd1(&beta)); +#endif + + for (i = 0; i < m32; i += 32) { + for (j = 0; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); + + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); + + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); + } + STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); + STORE_512(0, 1); STORE_512(1, 1); STORE_512(2, 1); STORE_512(3, 1); + STORE_512(0, 2); STORE_512(1, 2); STORE_512(2, 2); STORE_512(3, 2); + STORE_512(0, 3); STORE_512(1, 3); STORE_512(2, 3); STORE_512(3, 3); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + } + STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); + STORE_512(0, 1); STORE_512(1, 1); STORE_512(2, 1); STORE_512(3, 1); + } + for (; j < N; j++) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); + BROADCAST_LOAD_B_512(x, 0); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + } + STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); + } + } + for (; i < m16; i += 16) { + for (j = 0; j < n6; j += 6) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); + DECLARE_RESULT_512(0, 4); DECLARE_RESULT_512(1, 4); + DECLARE_RESULT_512(0, 5); DECLARE_RESULT_512(1, 5); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); 
LOAD_A_512(1, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + MATMUL_512(0, 4); MATMUL_512(1, 4); + MATMUL_512(0, 5); MATMUL_512(1, 5); + } + STORE_512(0, 0); STORE_512(1, 0); + STORE_512(0, 1); STORE_512(1, 1); + STORE_512(0, 2); STORE_512(1, 2); + STORE_512(0, 3); STORE_512(1, 3); + STORE_512(0, 4); STORE_512(1, 4); + STORE_512(0, 5); STORE_512(1, 5); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + } + STORE_512(0, 0); STORE_512(1, 0); + STORE_512(0, 1); STORE_512(1, 1); + } + for (; j < N; j++) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); + BROADCAST_LOAD_B_512(x, 0); + MATMUL_512(0, 0); MATMUL_512(1, 0); + } + STORE_512(0, 0); STORE_512(1, 0); + } + } + for (; i < m8; i += 8) { + for (j = 0; j < n6; j += 6) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + DECLARE_RESULT_512(0, 4); + DECLARE_RESULT_512(0, 5); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + MATMUL_512(0, 4); + MATMUL_512(0, 5); + } + STORE_512(0, 0); + STORE_512(0, 1); + STORE_512(0, 2); + STORE_512(0, 3); + STORE_512(0, 4); + STORE_512(0, 5); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + STORE_512(0, 0); + STORE_512(0, 1); + } + for (; j < N; j++) { + DECLARE_RESULT_512(0, 0); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); + MATMUL_512(0, 0); + } + STORE_512(0, 0); + } + } + int mm = M - i; + if (!mm) return 0; + if (mm > 4 || K < 16) { + register __mmask8 mask asm("k1") = (1UL << mm) - 1; + for (j = 0; j < n6; j += 6) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + DECLARE_RESULT_512(0, 4); + DECLARE_RESULT_512(0, 5); + for (k = 0; k < K; k++) { + MASK_LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + MATMUL_512(0, 4); + MATMUL_512(0, 5); + } + MASK_STORE_512(0, 0); + MASK_STORE_512(0, 1); + MASK_STORE_512(0, 2); + MASK_STORE_512(0, 3); + MASK_STORE_512(0, 4); + MASK_STORE_512(0, 5); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + for (k = 0; k < K; k++) { + MASK_LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + MASK_STORE_512(0, 0); + 
MASK_STORE_512(0, 1); + } + for (; j < N; j++) { + DECLARE_RESULT_512(0, 0); + for (k = 0; k < K; k++) { + MASK_LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); + MATMUL_512(0, 0); + } + MASK_STORE_512(0, 0); + } + } else { + /* M => [1, 4] + * + * This kernel uses a dot-product style to compute each value C(x, y): + * C(x, y) = A(x, 0)*B(0, y) + A(x, 1)*B(1, y) + ... + A(x, K-1)*B(K-1, y) + * + * The remaining rows of A are first copied into a buffer in row-major order, + * so that memory access from 0 to K is contiguous for both A and B. + * + * Each iteration loads 8 values of k into a zmm register and FMAs them; + * a final reduce_add then collapses the zmm into the single scalar C(x, y). + * + * Note: performance is poor when K is small. + */
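+ /* Staging-buffer layout: mbuf[ii*K + k] holds A(i+ii, k), i.e. row ii of the + * A tail stored contiguously along k. The 4x4 transpose below fills it four + * columns of A at a time; the scalar loop handles the K % 4 remainder. */ + FLOAT *mbuf = (FLOAT *) malloc(sizeof(FLOAT)*mm*K); + __mmask8 mask = (1UL << mm) - 1; + BLASLONG k8 = K & ~7; + BLASLONG k4 = K & ~3; + for (k = 0; k < k4; k += 4) { + __m256d r0, r1, r2, r3; + __m256d t0, t1, t2, t3; + r0 = _mm256_maskz_loadu_pd(mask, &A[i + lda*(0 + k)]); + r1 = _mm256_maskz_loadu_pd(mask, &A[i + lda*(1 + k)]); + r2 = _mm256_maskz_loadu_pd(mask, &A[i + lda*(2 + k)]); + r3 = _mm256_maskz_loadu_pd(mask, &A[i + lda*(3 + k)]); + + t0 = _mm256_unpacklo_pd(r0, r1); + t1 = _mm256_unpackhi_pd(r0, r1); + t2 = _mm256_unpacklo_pd(r2, r3); + t3 = _mm256_unpackhi_pd(r2, r3); + + r0 = _mm256_permute2f128_pd(t0, t2, 0x20); + r1 = _mm256_permute2f128_pd(t1, t3, 0x20); + r2 = _mm256_permute2f128_pd(t0, t2, 0x31); + r3 = _mm256_permute2f128_pd(t1, t3, 0x31); + + switch (mm) { + case 4: _mm256_storeu_pd(&mbuf[k + 3*K], r3); + case 3: _mm256_storeu_pd(&mbuf[k + 2*K], r2); + case 2: _mm256_storeu_pd(&mbuf[k + 1*K], r1); + case 1: _mm256_storeu_pd(&mbuf[k + 0*K], r0); + } + } + for (; k < K; k++) { + for (int ii = 0; ii < mm; ii++) { + mbuf[k + ii*K] = A[i + lda*k + ii]; + } + } + int mi = 0; + __m256d alpha_256 = _mm256_broadcast_sd(&alpha); +#if !defined(B0) + __m256d beta_256 = _mm256_broadcast_sd(&beta); +#endif + __m256i vindex_n = _mm256_set_epi64x(ldc*3, ldc*2, ldc*1, 0); + long long permute_table[] = { + 0, 1, 0|8, 1|8, 4, 5, 4|8, 5|8, + 2, 3, 2|8, 3|8, 6, 7, 6|8, 7|8, + }; + __m512i idx_lo = _mm512_loadu_si512(permute_table); + __m512i idx_hi = _mm512_loadu_si512(permute_table + 8); + for (; i < m4; i += 4, mi += 4) { + for (j = 0; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); 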
MATMUL_512(2, 1); MATMUL_512(3, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); + } + STORE_REDUCE_M4(0); STORE_REDUCE_M4(1); STORE_REDUCE_M4(2); STORE_REDUCE_M4(3); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + } + STORE_REDUCE_M4(0); STORE_REDUCE_M4(1); + } + for (; j < N; j += 1) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); + LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); + MASK_LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + } + STORE_REDUCE_M4(0); + } + + } + for (; i < m2; i += 2, mi += 2) { + for (j = 0; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + } + STORE_REDUCE_N4(0); STORE_REDUCE_N4(1); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + } + STORE_REDUCE(0, 0); STORE_REDUCE(1, 0); + STORE_REDUCE(0, 1); 
STORE_REDUCE(1, 1); + + } + for (; j < N; j += 1) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); + LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); + MASK_LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + } + STORE_REDUCE(0, 0); STORE_REDUCE(1, 0); + } + } + for (; i < M; i += 1, mi += 1) { + for (j = 0; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); + + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + } + STORE_REDUCE_N4(0); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + STORE_REDUCE(0, 0); + STORE_REDUCE(0, 1); + + } + for (; j < N; j += 1) { + DECLARE_RESULT_512(0, 0); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); + LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); + MASK_LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); + } + STORE_REDUCE(0, 0); + } + } + free(mbuf); + } + return 0; +} diff --git a/kernel/x86_64/dgemm_small_kernel_nt_skylakex.c b/kernel/x86_64/dgemm_small_kernel_nt_skylakex.c new file mode 100644 index 000000000..e757197ba --- /dev/null +++ b/kernel/x86_64/dgemm_small_kernel_nt_skylakex.c @@ -0,0 +1,535 @@ +/*************************************************************************** +Copyright (c) 2021, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. 
IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include <immintrin.h> +#include "common.h" +#include <stdio.h> +#include <memory.h> + +#define DECLARE_RESULT_512(M, N) __m512d result##M##N = _mm512_setzero_pd() +#define LOAD_A_512(M, N) __m512d Aval##M = _mm512_loadu_pd(&A[lda * k + i + (M*8)]) +#define MASK_LOAD_A_512(M, N) __m512d Aval##M = _mm512_maskz_loadu_pd(mask, &A[lda * k + i + (M*8)]) +#define BROADCAST_LOAD_B_512(M, N) __m512d Bval##N = _mm512_broadcastsd_pd(_mm_load_sd(&B[ldb * k + j + N])) +#define MATMUL_512(M, N) result##M##N = _mm512_fmadd_pd(Aval##M, Bval##N, result##M##N) + +#define BROADCAST_LOAD_A_512(M, N) __m512d Aval##M = _mm512_broadcastsd_pd(_mm_load_sd(&A[lda * k + i + M])) +#define LOAD_B_512(M, N) __m512d Bval##N = _mm512_loadu_pd(&B[ldb * k + j + (N*8)]) +#define MASK_LOAD_B_512(M, N) __m512d Bval##N = _mm512_maskz_loadu_pd(mask, &B[ldb * k + j + (N*8)])
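+/* For beta != 0 the store paths use inline asm (vfmadd231pd with a memory + operand) to fold the load of C and the beta multiply-add into a single + instruction; the scatter/gather variants address one C column per vector + lane through the vindex_n index vector. */ +#if defined(B0) +#define STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ + _mm512_storeu_pd(&C[(j+N)*ldc + i + (M*8)], result##M##N) +#define MASK_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ + _mm512_mask_storeu_pd(&C[(j+N)*ldc + i + (M*8)], mask, result##M##N) +#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ + _mm512_i64scatter_pd(&C[(j + N*8)*ldc + i + M], vindex_n, result##M##N, 8); +#define MASK_SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ + _mm512_mask_i64scatter_pd(&C[(j + N*8)*ldc + i + M], mask, vindex_n, result##M##N, 8) +#else +#define STORE_512(M, N) \ + result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ + asm("vfmadd231pd (%1), %2, %0": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*8)]), "v"(beta_512)); \ + _mm512_storeu_pd(&C[(j+N)*ldc + i + (M*8)], result##M##N) +#define MASK_STORE_512(M, N) \ + result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ + asm("vfmadd231pd (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*8)]), "v"(beta_512), "k"(mask)); \ + _mm512_mask_storeu_pd(&C[(j+N)*ldc + i + (M*8)], mask, result##M##N) +#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ + __m512d tmp##M##N = _mm512_i64gather_pd(vindex_n, &C[(j + N*8)*ldc + i + M], 8); \ + result##M##N = _mm512_fmadd_pd(tmp##M##N, beta_512, result##M##N); \ + _mm512_i64scatter_pd(&C[(j + N*8)*ldc + i + M], vindex_n, result##M##N, 8); +#define MASK_SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \ + __m512d tmp##M##N = _mm512_mask_i64gather_pd(_mm512_setzero_pd(), mask, vindex_n, &C[(j + N*8)*ldc + i + M], 8); \ + result##M##N = _mm512_fmadd_pd(tmp##M##N, beta_512, result##M##N); \ + _mm512_mask_i64scatter_pd(&C[(j + N*8)*ldc + i + M], mask, vindex_n, result##M##N, 8); +#endif + +#if defined(B0) +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) +#else +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, 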
BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) +#endif +{ + // column major + BLASLONG i, j, k; + + BLASLONG m32 = M & ~31; + BLASLONG m16 = M & ~15; + BLASLONG m8 = M & ~7; + BLASLONG m4 = M & ~3; + BLASLONG m2 = M & ~1; + + BLASLONG n32 = N & ~31; + BLASLONG n16 = N & ~15; + BLASLONG n8 = N & ~7; + BLASLONG n6 = N - (N % 6); + BLASLONG n4 = N & ~3; + BLASLONG n2 = N & ~1; + + + __m512d alpha_512 = _mm512_broadcastsd_pd(_mm_load_sd(&alpha)); +#if !defined(B0) + __m512d beta_512 = _mm512_broadcastsd_pd(_mm_load_sd(&beta)); +#endif + + for (i = 0; i < m32; i += 32) { + for (j = 0; j < n6; j += 6) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); + DECLARE_RESULT_512(0, 4); DECLARE_RESULT_512(1, 4); DECLARE_RESULT_512(2, 4); DECLARE_RESULT_512(3, 4); + DECLARE_RESULT_512(0, 5); DECLARE_RESULT_512(1, 5); DECLARE_RESULT_512(2, 5); DECLARE_RESULT_512(3, 5); + + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); + + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); + BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); + MATMUL_512(0, 4); MATMUL_512(1, 4); MATMUL_512(2, 4); MATMUL_512(3, 4); + MATMUL_512(0, 5); MATMUL_512(1, 5); MATMUL_512(2, 5); MATMUL_512(3, 5); + } + STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); + STORE_512(0, 1); STORE_512(1, 1); STORE_512(2, 1); STORE_512(3, 1); + STORE_512(0, 2); STORE_512(1, 2); STORE_512(2, 2); STORE_512(3, 2); + STORE_512(0, 3); STORE_512(1, 3); STORE_512(2, 3); STORE_512(3, 3); + STORE_512(0, 4); STORE_512(1, 4); STORE_512(2, 4); STORE_512(3, 4); + STORE_512(0, 5); STORE_512(1, 5); STORE_512(2, 5); STORE_512(3, 5); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + } + STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); + STORE_512(0, 1); STORE_512(1, 1); STORE_512(2, 1); STORE_512(3, 1); + } + for (; j < N; j++) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); + BROADCAST_LOAD_B_512(x, 0); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + } + STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); + } + } + for 
(; i < m16; i += 16) { + for (j = 0; j < n8; j += 8) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); + DECLARE_RESULT_512(0, 4); DECLARE_RESULT_512(1, 4); + DECLARE_RESULT_512(0, 5); DECLARE_RESULT_512(1, 5); + DECLARE_RESULT_512(0, 6); DECLARE_RESULT_512(1, 6); + DECLARE_RESULT_512(0, 7); DECLARE_RESULT_512(1, 7); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); + BROADCAST_LOAD_B_512(x, 6); BROADCAST_LOAD_B_512(x, 7); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + MATMUL_512(0, 4); MATMUL_512(1, 4); + MATMUL_512(0, 5); MATMUL_512(1, 5); + MATMUL_512(0, 6); MATMUL_512(1, 6); + MATMUL_512(0, 7); MATMUL_512(1, 7); + } + STORE_512(0, 0); STORE_512(1, 0); + STORE_512(0, 1); STORE_512(1, 1); + STORE_512(0, 2); STORE_512(1, 2); + STORE_512(0, 3); STORE_512(1, 3); + STORE_512(0, 4); STORE_512(1, 4); + STORE_512(0, 5); STORE_512(1, 5); + STORE_512(0, 6); STORE_512(1, 6); + STORE_512(0, 7); STORE_512(1, 7); + } + for (;j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + } + STORE_512(0, 0); STORE_512(1, 0); + STORE_512(0, 1); STORE_512(1, 1); + STORE_512(0, 2); STORE_512(1, 2); + STORE_512(0, 3); STORE_512(1, 3); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + } + STORE_512(0, 0); STORE_512(1, 0); + STORE_512(0, 1); STORE_512(1, 1); + } + for (; j < N; j++) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); + BROADCAST_LOAD_B_512(x, 0); + MATMUL_512(0, 0); MATMUL_512(1, 0); + } + STORE_512(0, 0); STORE_512(1, 0); + } + } + for (; i < m8; i += 8) { + for (j = 0; j < n8; j += 8) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + DECLARE_RESULT_512(0, 4); + DECLARE_RESULT_512(0, 5); + DECLARE_RESULT_512(0, 6); + DECLARE_RESULT_512(0, 7); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); + BROADCAST_LOAD_B_512(x, 6); BROADCAST_LOAD_B_512(x, 7); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + MATMUL_512(0, 4); + MATMUL_512(0, 5); + MATMUL_512(0, 6); + MATMUL_512(0, 7); + } + STORE_512(0, 0); + STORE_512(0, 
1); + STORE_512(0, 2); + STORE_512(0, 3); + STORE_512(0, 4); + STORE_512(0, 5); + STORE_512(0, 6); + STORE_512(0, 7); + } + for (; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + } + STORE_512(0, 0); + STORE_512(0, 1); + STORE_512(0, 2); + STORE_512(0, 3); + } + + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + STORE_512(0, 0); + STORE_512(0, 1); + } + for (; j < N; j++) { + DECLARE_RESULT_512(0, 0); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); + MATMUL_512(0, 0); + } + STORE_512(0, 0); + } + } + int mm = M - i; + if (mm >= 6) { + register __mmask16 mask asm("k1") = (1UL << mm) - 1; + for (j = 0; j < n8; j += 8) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + DECLARE_RESULT_512(0, 4); + DECLARE_RESULT_512(0, 5); + DECLARE_RESULT_512(0, 6); + DECLARE_RESULT_512(0, 7); + for (k = 0; k < K; k++) { + MASK_LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); + BROADCAST_LOAD_B_512(x, 6); BROADCAST_LOAD_B_512(x, 7); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + MATMUL_512(0, 4); + MATMUL_512(0, 5); + MATMUL_512(0, 6); + MATMUL_512(0, 7); + } + MASK_STORE_512(0, 0); + MASK_STORE_512(0, 1); + MASK_STORE_512(0, 2); + MASK_STORE_512(0, 3); + MASK_STORE_512(0, 4); + MASK_STORE_512(0, 5); + MASK_STORE_512(0, 6); + MASK_STORE_512(0, 7); + } + for (; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + for (k = 0; k < K; k++) { + MASK_LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + } + MASK_STORE_512(0, 0); + MASK_STORE_512(0, 1); + MASK_STORE_512(0, 2); + MASK_STORE_512(0, 3); + } + + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + for (k = 0; k < K; k++) { + MASK_LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + MASK_STORE_512(0, 0); + MASK_STORE_512(0, 1); + } + for (; j < N; j++) { + DECLARE_RESULT_512(0, 0); + for (k = 0; k < K; k++) { + MASK_LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); + MATMUL_512(0, 0); + } + MASK_STORE_512(0, 0); + } + } else if (mm > 0) { + long long index_n[8]; + for (int ii = 0; ii < 8; ii++) { + index_n[ii] = ii * ldc; + } + __m512i vindex_n = _mm512_loadu_si512(index_n); + for (; i < m4; i += 4) { + for (j = 0; j < n32; j += 32) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); 
+ DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); + LOAD_B_512(x, 0); + LOAD_B_512(x, 1); + LOAD_B_512(x, 2); + LOAD_B_512(x, 3); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); + } + SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0); SCATTER_STORE_512(2, 0); SCATTER_STORE_512(3, 0); + SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1); SCATTER_STORE_512(2, 1); SCATTER_STORE_512(3, 1); + SCATTER_STORE_512(0, 2); SCATTER_STORE_512(1, 2); SCATTER_STORE_512(2, 2); SCATTER_STORE_512(3, 2); + SCATTER_STORE_512(0, 3); SCATTER_STORE_512(1, 3); SCATTER_STORE_512(2, 3); SCATTER_STORE_512(3, 3); + } + for (; j < n16; j += 16) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); + LOAD_B_512(x, 0); + LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + } + SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0); SCATTER_STORE_512(2, 0); SCATTER_STORE_512(3, 0); + SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1); SCATTER_STORE_512(2, 1); SCATTER_STORE_512(3, 1); + } + __mmask8 mask = 0xff; + for (; j < N; j += 8) { + int remains = N - j; + if (remains < 8) mask = (1UL << remains) - 1; + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); + MASK_LOAD_B_512(x, 0); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + } + MASK_SCATTER_STORE_512(0, 0); MASK_SCATTER_STORE_512(1, 0); MASK_SCATTER_STORE_512(2, 0); MASK_SCATTER_STORE_512(3, 0); + } + } + for (; i < m2; i += 2) { + for (j = 0; j < n32; j += 32) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); + LOAD_B_512(x, 0); + LOAD_B_512(x, 1); + LOAD_B_512(x, 2); + LOAD_B_512(x, 3); + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + } + SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0); + SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1); + SCATTER_STORE_512(0, 2); SCATTER_STORE_512(1, 2); + SCATTER_STORE_512(0, 3); SCATTER_STORE_512(1, 3); + } + for (; j < n16; j += 16) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); + LOAD_B_512(x, 0); + LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 
1); MATMUL_512(1, 1); + } + SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0); + SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1); + } + __mmask8 mask = 0xff; + for (; j < N; j += 8) { + int remains = N - j; + if (remains < 8) mask = (1UL << remains) - 1; + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); + MASK_LOAD_B_512(x, 0); + MATMUL_512(0, 0); MATMUL_512(1, 0); + } + MASK_SCATTER_STORE_512(0, 0); MASK_SCATTER_STORE_512(1, 0); + } + } + for (; i < M; i += 1) { + for (j = 0; j < n32; j += 32) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); + LOAD_B_512(x, 0); + LOAD_B_512(x, 1); + LOAD_B_512(x, 2); + LOAD_B_512(x, 3); + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + } + SCATTER_STORE_512(0, 0); + SCATTER_STORE_512(0, 1); + SCATTER_STORE_512(0, 2); + SCATTER_STORE_512(0, 3); + } + for (; j < n16; j += 16) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); + LOAD_B_512(x, 0); + LOAD_B_512(x, 1); + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + SCATTER_STORE_512(0, 0); + SCATTER_STORE_512(0, 1); + } + __mmask8 mask = 0xff; + for (; j < N; j += 8) { + int remains = N - j; + if (remains < 8) mask = (1UL << remains) - 1; + DECLARE_RESULT_512(0, 0); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); + MASK_LOAD_B_512(x, 0); + MATMUL_512(0, 0); + } + MASK_SCATTER_STORE_512(0, 0); + } + } + } + return 0; +} diff --git a/kernel/x86_64/dgemm_small_kernel_permit_skylakex.c b/kernel/x86_64/dgemm_small_kernel_permit_skylakex.c new file mode 100644 index 000000000..9cca08e71 --- /dev/null +++ b/kernel/x86_64/dgemm_small_kernel_permit_skylakex.c @@ -0,0 +1,44 @@ +/*************************************************************************** +Copyright (c) 2021, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*****************************************************************************/
+
+#include "common.h"
+
+int CNAME(int transa, int transb, BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, FLOAT beta)
+{
+    double MNK = (double) M * (double) N * (double) K;
+    if (MNK > 100.0*100.0*100.0) // disable for large matrices
+        return 0;
+    if (transa && !transb) {
+        /* The TN kernel performs poorly when:
+         * 1. the C matrix is too large
+         * 2. K is too small
+         */
+        if (M * N > 1200 || K < 32)
+            return 0;
+    }
+    return 1;
+}
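The gate above keeps the small-matrix kernels off anything bigger than roughly a 100x100x100 product, and is stricter in the TN case, where the reduction-style kernel only pays off for a small C and a reasonably deep K. A standalone sketch mirroring that logic (the driver below, including `small_kernel_permit`, is illustrative and not part of the patch):

#include <stdio.h>

/* Mirror of the permit logic above; nonzero transa/transb means the
 * corresponding operand is transposed. */
static int small_kernel_permit(int transa, int transb, long M, long N, long K)
{
    double MNK = (double)M * (double)N * (double)K;
    if (MNK > 100.0 * 100.0 * 100.0)   /* too much work for the small kernel */
        return 0;
    if (transa && !transb) {           /* TN case: see the comment above */
        if (M * N > 1200 || K < 32)
            return 0;
    }
    return 1;
}

int main(void)
{
    printf("NN 64x64x64:    %d\n", small_kernel_permit(0, 0, 64, 64, 64));    /* 1 */
    printf("TN 30x30x16:    %d\n", small_kernel_permit(1, 0, 30, 30, 16));    /* 0: K < 32 */
    printf("TN 30x30x64:    %d\n", small_kernel_permit(1, 0, 30, 30, 64));    /* 1 */
    printf("NN 128x128x128: %d\n", small_kernel_permit(0, 0, 128, 128, 128)); /* 0: MNK too large */
    return 0;
}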
diff --git a/kernel/x86_64/dgemm_small_kernel_tn_skylakex.c b/kernel/x86_64/dgemm_small_kernel_tn_skylakex.c
new file mode 100644
index 000000000..18c797283
--- /dev/null
+++ b/kernel/x86_64/dgemm_small_kernel_tn_skylakex.c
@@ -0,0 +1,322 @@
+/***************************************************************************
+Copyright (c) 2021, The OpenBLAS Project
+All rights reserved.
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+1. Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in
+the documentation and/or other materials provided with the
+distribution.
+3. Neither the name of the OpenBLAS project nor the names of
+its contributors may be used to endorse or promote products
+derived from this software without specific prior written permission.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*****************************************************************************/
+
+#include <immintrin.h>
+#include "common.h"
+#include <stdio.h>
+#include <memory.h>
+
+#define DECLARE_RESULT_512(M, N) __m512d result##M##N = _mm512_setzero_pd()
+#define MATMUL_512(M, N) result##M##N = _mm512_fmadd_pd(Aval##M, Bval##N, result##M##N)
+
+#define LOAD_KA_512(M, N) __m512d Aval##M = _mm512_loadu_pd(&A[(i + M)*lda + k]);
+#define LOAD_KB_512(M, N) __m512d Bval##N = _mm512_loadu_pd(&B[(j + N)*ldb + k])
+#define MASK_LOAD_KA_512(M, N) __m512d Aval##M = _mm512_maskz_loadu_pd(mask, &A[(i + M)*lda + k])
+#define MASK_LOAD_KB_512(M, N) __m512d Bval##N = _mm512_maskz_loadu_pd(mask, &B[(j + N)*ldb + k])
+
+#define REDUCE_4(rr0, rr1, rr2, rr3) \
+    __m512d r0, r1, r2, r3, t0, t1, t2, t3;\
+    r0 = _mm512_unpacklo_pd(rr0, rr1); r1 = _mm512_unpackhi_pd(rr0, rr1); \
+    r2 = _mm512_unpacklo_pd(rr2, rr3); r3 = _mm512_unpackhi_pd(rr2, rr3); \
+    t0 = _mm512_permutex2var_pd(r0, idx_lo, r2); t1 = _mm512_permutex2var_pd(r1, idx_lo, r3); \
+    t2 = _mm512_permutex2var_pd(r0, idx_hi, r2); t3 = _mm512_permutex2var_pd(r1, idx_hi, r3); \
+    r0 = _mm512_add_pd(t0, t1); r1 = _mm512_add_pd(t2, t3); t0 = _mm512_add_pd(r0, r1); \
+    __m256d s0, s1; \
+    s0 = _mm512_extractf64x4_pd(t0, 0); s1 = _mm512_extractf64x4_pd(t0, 1); \
+    s0 = _mm256_add_pd(s0, s1); s0 = _mm256_mul_pd(alpha_256, s0);
+
+#define REDUCE_M4(N) REDUCE_4(result0##N, result1##N, result2##N, result3##N)
+#define REDUCE_N4(M) REDUCE_4(result##M##0, result##M##1, result##M##2, result##M##3)
+
+#if defined(B0)
+#define STORE_REDUCE(M, N) C[(j+N)*ldc + i + M] = alpha * _mm512_reduce_add_pd(result##M##N)
+#define STORE_M4(N, s0) _mm256_storeu_pd(&C[(j + N)*ldc + i], s0);
+#define STORE_N4(M, s0) _mm256_i64scatter_pd(&C[j*ldc + i + M], vindex_n, s0, 8);
+#else
+#define STORE_REDUCE(M, N) C[(j+N)*ldc + i + M] = alpha * _mm512_reduce_add_pd(result##M##N) + beta * C[(j+N)*ldc + i + M]
+#define STORE_M4(N, s0) \
+    asm("vfmadd231pd (%1), %2, %0": "+v"(s0):"r"(&C[(j + N)*ldc + i]), "v"(beta_256)); \
+    _mm256_storeu_pd(&C[(j + N)*ldc + i], s0);
+
+#define STORE_N4(M, s0) \
+    s0 = _mm256_fmadd_pd(_mm256_i64gather_pd(&C[j*ldc + i + M], vindex_n, 8), beta_256, s0); \
+    _mm256_i64scatter_pd(&C[j*ldc + i + M], vindex_n, s0, 8);
+#endif
+#define STORE_REDUCE_M4(N) {\
+    REDUCE_M4(N) \
+    STORE_M4(N, s0) \
+}
+#define STORE_REDUCE_N4(M) {\
+    REDUCE_N4(M) \
+    STORE_N4(M, s0) \
+}
+
+
+#if defined(B0)
+int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc)
+#else
+int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc)
+#endif
+{
+    // column major
+    BLASLONG i, j, k;
+
+    BLASLONG m4 = M & ~3;
+    BLASLONG m2 = M & ~1;
+
+    BLASLONG n4 = N & ~3;
+    BLASLONG n2 = N & ~1;
+
+    BLASLONG k8 = K & ~7;
+
+    __mmask8 mask;
+
+    __m256i vindex_n = _mm256_set_epi64x(ldc*3, ldc*2, ldc, 0);
+    __m256d alpha_256 = _mm256_broadcast_sd(&alpha);
+#if !defined(B0)
+    __m256d beta_256 = _mm256_broadcast_sd(&beta);
+#endif
+
+    long long permute_table[] = {
+        0, 1, 0|8, 1|8, 4, 5, 4|8, 5|8,
+        2, 3, 2|8, 3|8, 6, 7, 6|8, 7|8,
+    };
+    __m512i idx_lo = _mm512_loadu_si512(permute_table);
+    __m512i idx_hi = _mm512_loadu_si512(permute_table + 8);
+
+    for (i = 0; i < m4; i += 4) {
+        for (j = 0; j < n4; j += 4) {
+            DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0);
+            DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1);
DECLARE_RESULT_512(3, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); + } + STORE_REDUCE_M4(0); STORE_REDUCE_M4(1); STORE_REDUCE_M4(2); STORE_REDUCE_M4(3); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + } + STORE_REDUCE_M4(0); STORE_REDUCE_M4(1); + } + for (; j < N; j += 1) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); + LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); + MASK_LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + } + STORE_REDUCE_M4(0); + } + + } + for (; i < m2; i += 2) { + for (j = 0; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 
3); MATMUL_512(1, 3); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + } + STORE_REDUCE_N4(0); STORE_REDUCE_N4(1); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + } + STORE_REDUCE(0, 0); STORE_REDUCE(1, 0); + STORE_REDUCE(0, 1); STORE_REDUCE(1, 1); + + } + for (; j < N; j += 1) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); + LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); + MASK_LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + } + STORE_REDUCE(0, 0); STORE_REDUCE(1, 0); + } + } + for (; i < M; i += 1) { + for (j = 0; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); + + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + } + STORE_REDUCE_N4(0); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + STORE_REDUCE(0, 0); + STORE_REDUCE(0, 1); + + } + for (; j < N; j += 1) { + DECLARE_RESULT_512(0, 0); + for (k = 0; k < k8; k += 8) { + LOAD_KA_512(0, x); + LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); + MASK_LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); + } + STORE_REDUCE(0, 0); + } + } + return 0; +} diff --git a/kernel/x86_64/dgemm_small_kernel_tt_skylakex.c b/kernel/x86_64/dgemm_small_kernel_tt_skylakex.c new file mode 100644 index 000000000..00f42aa76 --- /dev/null +++ b/kernel/x86_64/dgemm_small_kernel_tt_skylakex.c @@ -0,0 +1,392 @@ +/*************************************************************************** +Copyright (c) 2021, The OpenBLAS Project +All rights reserved. 
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+1. Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in
+the documentation and/or other materials provided with the
+distribution.
+3. Neither the name of the OpenBLAS project nor the names of
+its contributors may be used to endorse or promote products
+derived from this software without specific prior written permission.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*****************************************************************************/
+
+#include <immintrin.h>
+#include "common.h"
+#include <stdio.h>
+
+#define DECLARE_RESULT_512(M, N) __m512d result##M##N = _mm512_setzero_pd()
+#define BROADCAST_LOAD_A_512(M, N) __m512d Aval##M = _mm512_broadcastsd_pd(_mm_load_sd(&A[k + lda * (i+M)]))
+#define LOAD_B_512(M,N) __m512d Bval##N = _mm512_loadu_pd(&B[ldb * k + j + (N*8)])
+#define MASK_LOAD_B_512(M, N) __m512d Bval##N = _mm512_maskz_loadu_pd(mask, &B[ldb * k + j + (N*8)])
+#define MATMUL_512(M, N) result##M##N = _mm512_fmadd_pd(Aval##M, Bval##N, result##M##N)
+
+#if defined(B0)
+#define STORE_8xy(v, N, x, y) _mm512_storeu_pd(&C[(j + N*8 + x + y*8)*ldc + i], v)
+#define STORE_4xy(v, N, x, y) _mm256_storeu_pd(&C[(j + N*8 + x + y*4)*ldc + i], v)
+#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \
+    _mm512_i64scatter_pd(&C[(j + N*8)*ldc + i + M], vindex_n, result##M##N, 8);
+#define MASK_SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \
+    _mm512_mask_i64scatter_pd(&C[(j + N*8)*ldc + i + M], mask, vindex_n, result##M##N, 8);
+#else
+#define STORE_8xy(v, N, x, y) \
+    asm("vfmadd231pd (%1), %2, %0": "+v"(v): "r"(&C[(j + N*8 + x + y*8)*ldc + i]), "v"(beta_512)); \
+    _mm512_storeu_pd(&C[(j + N*8 + x + y*8)*ldc + i], v)
+#define STORE_4xy(v, N, x, y) \
+    asm("vfmadd231pd (%1), %2, %0": "+v"(v): "r"(&C[(j + N*8 + x + y*4)*ldc + i]), "v"(beta_256)); \
+    _mm256_storeu_pd(&C[(j + N*8 + x + y*4)*ldc + i], v)
+#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \
+    __m512d tmp##M##N = _mm512_i64gather_pd(vindex_n, &C[(j + N*8)*ldc + i + M], 8); \
+    result##M##N = _mm512_fmadd_pd(tmp##M##N, beta_512, result##M##N); \
+    _mm512_i64scatter_pd(&C[(j + N*8)*ldc + i + M], vindex_n, result##M##N, 8);
+#define MASK_SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_pd(result##M##N, alpha_512); \
+    __m512d tmp##M##N = _mm512_mask_i64gather_pd(_mm512_setzero_pd(), mask, vindex_n, &C[(j + N*8)*ldc + i + M], 8); \
+    result##M##N = _mm512_fmadd_pd(tmp##M##N, beta_512,
result##M##N); \ + _mm512_mask_i64scatter_pd(&C[(j + N*8)*ldc + i + M], mask, vindex_n, result##M##N, 8); +#endif + +#define REORDER_8x8(r0, r1, r2, r3, r4, r5, r6, r7) \ + __m512d t0, t1, t2, t3, t4, t5, t6, t7; \ + t0 = _mm512_unpacklo_pd(r0, r1); \ + t1 = _mm512_unpackhi_pd(r0, r1); \ + t2 = _mm512_unpacklo_pd(r2, r3); \ + t3 = _mm512_unpackhi_pd(r2, r3); \ + t4 = _mm512_unpacklo_pd(r4, r5); \ + t5 = _mm512_unpackhi_pd(r4, r5); \ + t6 = _mm512_unpacklo_pd(r6, r7); \ + t7 = _mm512_unpackhi_pd(r6, r7); \ + r0 = _mm512_shuffle_f64x2(t0, t2, 0x88); \ + r1 = _mm512_shuffle_f64x2(t1, t3, 0x88); \ + r2 = _mm512_shuffle_f64x2(t0, t2, 0xdd); \ + r3 = _mm512_shuffle_f64x2(t1, t3, 0xdd); \ + r4 = _mm512_shuffle_f64x2(t4, t6, 0x88); \ + r5 = _mm512_shuffle_f64x2(t5, t7, 0x88); \ + r6 = _mm512_shuffle_f64x2(t4, t6, 0xdd); \ + r7 = _mm512_shuffle_f64x2(t5, t7, 0xdd); \ + t0 = _mm512_permutex2var_pd(r0, idx_lo, r4); \ + t1 = _mm512_permutex2var_pd(r1, idx_lo, r5); \ + t2 = _mm512_permutex2var_pd(r2, idx_lo, r6); \ + t3 = _mm512_permutex2var_pd(r3, idx_lo, r7); \ + t4 = _mm512_permutex2var_pd(r0, idx_hi, r4); \ + t5 = _mm512_permutex2var_pd(r1, idx_hi, r5); \ + t6 = _mm512_permutex2var_pd(r2, idx_hi, r6); \ + t7 = _mm512_permutex2var_pd(r3, idx_hi, r7); \ + t0 = _mm512_mul_pd(t0, alpha_512); \ + t1 = _mm512_mul_pd(t1, alpha_512); \ + t2 = _mm512_mul_pd(t2, alpha_512); \ + t3 = _mm512_mul_pd(t3, alpha_512); \ + t4 = _mm512_mul_pd(t4, alpha_512); \ + t5 = _mm512_mul_pd(t5, alpha_512); \ + t6 = _mm512_mul_pd(t6, alpha_512); \ + t7 = _mm512_mul_pd(t7, alpha_512); + +#define SAVE_8(N, x) {\ + STORE_8xy(t##x, N, x, 0); \ +} + +#define REORDER_STORE_8x8(N) {\ + REORDER_8x8(result0##N, result1##N, result2##N, result3##N, result4##N, result5##N, result6##N, result7##N); \ + SAVE_8(N, 0); SAVE_8(N, 1); SAVE_8(N, 2); SAVE_8(N, 3); SAVE_8(N, 4); SAVE_8(N, 5); SAVE_8(N, 6); SAVE_8(N, 7); \ +} + +#define MASK_SAVE_8() \ + switch (nn) { \ + case 8: SAVE_8(0, 7); \ + case 7: SAVE_8(0, 6); \ + case 6: SAVE_8(0, 5); \ + case 5: SAVE_8(0, 4); \ + case 4: SAVE_8(0, 3); \ + case 3: SAVE_8(0, 2); \ + case 2: SAVE_8(0, 1); \ + case 1: SAVE_8(0, 0); \ + } + +#define MASK_REORDER_STORE_8x8(N) {\ + REORDER_8x8(result0##N, result1##N, result2##N, result3##N, result4##N, result5##N, result6##N, result7##N); \ + MASK_SAVE_8(); \ +} + +#define REORDER_4x8(r0, r1, r2, r3) \ + __m512d t0, t1, t2, t3; \ + t0 = _mm512_unpacklo_pd(r0, r1); \ + t1 = _mm512_unpackhi_pd(r0, r1); \ + t2 = _mm512_unpacklo_pd(r2, r3); \ + t3 = _mm512_unpackhi_pd(r2, r3); \ + r0 = _mm512_permutex2var_pd(t0, idx_lo, t2); \ + r1 = _mm512_permutex2var_pd(t1, idx_lo, t3); \ + r2 = _mm512_permutex2var_pd(t0, idx_hi, t2); \ + r3 = _mm512_permutex2var_pd(t1, idx_hi, t3); \ + t0 = _mm512_mul_pd(r0, alpha_512); \ + t1 = _mm512_mul_pd(r1, alpha_512); \ + t2 = _mm512_mul_pd(r2, alpha_512); \ + t3 = _mm512_mul_pd(r3, alpha_512); + +#define SAVE_4(N, x, y) {\ + __m256d v4 = _mm512_extractf64x4_pd(t##x, y); \ + STORE_4xy(v4, N, x, y); \ +} + +#define REORDER_STORE_4x8(N) {\ + REORDER_4x8(result0##N, result1##N, result2##N, result3##N); \ + SAVE_4(N, 0, 0); SAVE_4(N, 1, 0); SAVE_4(N, 2, 0); SAVE_4(N, 3, 0); \ + SAVE_4(N, 0, 1); SAVE_4(N, 1, 1); SAVE_4(N, 2, 1); SAVE_4(N, 3, 1); \ +} + +#define MASK_SAVE_4() \ + switch (nn) { \ + case 8: SAVE_4(0, 3, 1); \ + case 7: SAVE_4(0, 2, 1); \ + case 6: SAVE_4(0, 1, 1); \ + case 5: SAVE_4(0, 0, 1); \ + case 4: SAVE_4(0, 3, 0); \ + case 3: SAVE_4(0, 2, 0); \ + case 2: SAVE_4(0, 1, 0); \ + case 1: SAVE_4(0, 0, 0); \ + } + +#define 
MASK_REORDER_STORE_4x8(N) {\ + REORDER_4x8(result0##N, result1##N, result2##N, result3##N); \ + MASK_SAVE_4(); \ +} + + +#if defined(B0) +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) +#else +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) +#endif +{ + // column major + BLASLONG i, j, k; + + BLASLONG m8 = M & ~7; + BLASLONG m4 = M & ~3; + BLASLONG m2 = M & ~1; + + BLASLONG n32 = N & ~31; + BLASLONG n16 = N & ~15; + + __m512d alpha_512 = _mm512_broadcastsd_pd(_mm_load_sd(&alpha)); +#if !defined(B0) + __m512d beta_512 = _mm512_broadcastsd_pd(_mm_load_sd(&beta)); + __m256d beta_256 = _mm256_broadcastsd_pd(_mm_load_sd(&beta)); +#endif + long long permute_table[] = { + 0, 1, 4, 5, 0|8, 1|8, 4|8, 5|8, + 2, 3, 6, 7, 2|8, 3|8, 6|8, 7|8, + }; + __m512i idx_lo = _mm512_loadu_si512(permute_table); + __m512i idx_hi = _mm512_loadu_si512(permute_table + 8); + + for (i = 0; i < m8; i += 8) { + for (j = 0; j < n16; j += 16) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(4, 0); DECLARE_RESULT_512(5, 0); DECLARE_RESULT_512(6, 0); DECLARE_RESULT_512(7, 0); + + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + DECLARE_RESULT_512(4, 1); DECLARE_RESULT_512(5, 1); DECLARE_RESULT_512(6, 1); DECLARE_RESULT_512(7, 1); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); + BROADCAST_LOAD_A_512(4, x); BROADCAST_LOAD_A_512(5, x); BROADCAST_LOAD_A_512(6, x); BROADCAST_LOAD_A_512(7, x); + LOAD_B_512(x, 0); LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(4, 0); MATMUL_512(5, 0); MATMUL_512(6, 0); MATMUL_512(7, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + MATMUL_512(4, 1); MATMUL_512(5, 1); MATMUL_512(6, 1); MATMUL_512(7, 1); + } + REORDER_STORE_8x8(0); + REORDER_STORE_8x8(1); + } + __mmask8 mask = 0xff; + int nn = 8; + for (; j < N; j += 8) { + if (N - j < 8) { + nn = N - j; + mask = (1UL << nn) - 1; + } + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(4, 0); DECLARE_RESULT_512(5, 0); DECLARE_RESULT_512(6, 0); DECLARE_RESULT_512(7, 0); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); + BROADCAST_LOAD_A_512(4, x); BROADCAST_LOAD_A_512(5, x); BROADCAST_LOAD_A_512(6, x); BROADCAST_LOAD_A_512(7, x); + MASK_LOAD_B_512(x, 0); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(4, 0); MATMUL_512(5, 0); MATMUL_512(6, 0); MATMUL_512(7, 0); + } + MASK_REORDER_STORE_8x8(0); + } + } + for (; i < m4; i += 4) { + long long permute_table2[] = { + 0, 1, 0|8, 1|8, 4, 5, 4|8, 5|8, + 2, 3, 2|8, 3|8, 6, 7, 6|8, 7|8, + }; + idx_lo = _mm512_loadu_si512(permute_table2); + idx_hi = _mm512_loadu_si512(permute_table2 + 8); + + for (j = 0; j < n32; j += 32) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); 
DECLARE_RESULT_512(3, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); + LOAD_B_512(x, 0); LOAD_B_512(x, 1); LOAD_B_512(x, 2); LOAD_B_512(x, 3); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); + } + REORDER_STORE_4x8(0); + REORDER_STORE_4x8(1); + REORDER_STORE_4x8(2); + REORDER_STORE_4x8(3); + } + for (; j < n16; j += 16) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); + LOAD_B_512(x, 0); LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + } + REORDER_STORE_4x8(0); + REORDER_STORE_4x8(1); + } + __mmask8 mask = 0xff; + int nn = 8; + for (; j < N; j += 8) { + if (N - j < 8) { + nn = N - j; + mask = (1UL << nn) - 1; + } + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); + MASK_LOAD_B_512(x, 0); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + } + MASK_REORDER_STORE_4x8(0); + } + } + if (i < M) { + long long index_n[8]; + for (int ii = 0; ii < 8; ii++) { + index_n[ii] = ii * ldc; + } + __m512i vindex_n = _mm512_loadu_si512(index_n); +#if !defined(B0) + __m512d beta_512 = _mm512_broadcastsd_pd(_mm_load_sd(&beta)); +#endif + for (; i < m2; i += 2) { + for (j = 0; j < n32; j += 32) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); + LOAD_B_512(x, 0); LOAD_B_512(x, 1); LOAD_B_512(x, 2); LOAD_B_512(x, 3); + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + } + SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0); + SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1); + SCATTER_STORE_512(0, 2); SCATTER_STORE_512(1, 2); + SCATTER_STORE_512(0, 3); SCATTER_STORE_512(1, 3); + } + for (; j < n16; j += 16) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); + LOAD_B_512(x, 0); LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + } + SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0); + SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1); + } + __mmask8 mask = 0xff; + int nn = 8; + for (; j < N; j += 8) { + if (N - j < 8) { + nn = N - j; + mask = (1UL << nn) - 1; + } + DECLARE_RESULT_512(0, 0); 
DECLARE_RESULT_512(1, 0); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); + MASK_LOAD_B_512(x, 0); + MATMUL_512(0, 0); MATMUL_512(1, 0); + } + MASK_SCATTER_STORE_512(0, 0); MASK_SCATTER_STORE_512(1, 0); + } + } + for (; i < M; i += 1) { + for (j = 0; j < n32; j += 32) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); + LOAD_B_512(x, 0); LOAD_B_512(x, 1); LOAD_B_512(x, 2); LOAD_B_512(x, 3); + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + } + SCATTER_STORE_512(0, 0); + SCATTER_STORE_512(0, 1); + SCATTER_STORE_512(0, 2); + SCATTER_STORE_512(0, 3); + } + for (; j < n16; j += 16) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); + LOAD_B_512(x, 0); LOAD_B_512(x, 1); + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + SCATTER_STORE_512(0, 0); + SCATTER_STORE_512(0, 1); + } + __mmask8 mask = 0xff; + int nn = 8; + for (; j < N; j += 8) { + if (N - j < 8) { + nn = N - j; + mask = (1UL << nn) - 1; + } + DECLARE_RESULT_512(0, 0); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); + MASK_LOAD_B_512(x, 0); + MATMUL_512(0, 0); + } + MASK_SCATTER_STORE_512(0, 0); + } + } + } + return 0; +} diff --git a/kernel/x86_64/sasum_microk_haswell-2.c b/kernel/x86_64/sasum_microk_haswell-2.c index 8e6cb9a47..2eb5b9538 100644 --- a/kernel/x86_64/sasum_microk_haswell-2.c +++ b/kernel/x86_64/sasum_microk_haswell-2.c @@ -38,10 +38,10 @@ static FLOAT sasum_kernel(BLASLONG n, FLOAT *x1) __m256i abs_mask = _mm256_set1_epi32(0x7fffffff); for (i = 0; i < tail_index_AVX2; i += 32) { - accum_0 += (__m256)_mm256_and_si256(_mm256_load_si256(&x1[i+ 0]), abs_mask); - accum_1 += (__m256)_mm256_and_si256(_mm256_load_si256(&x1[i+ 8]), abs_mask); - accum_2 += (__m256)_mm256_and_si256(_mm256_load_si256(&x1[i+16]), abs_mask); - accum_3 += (__m256)_mm256_and_si256(_mm256_load_si256(&x1[i+24]), abs_mask); + accum_0 += (__m256)_mm256_and_si256(_mm256_load_si256((__m256i*)&x1[i+ 0]), abs_mask); + accum_1 += (__m256)_mm256_and_si256(_mm256_load_si256((__m256i*)&x1[i+ 8]), abs_mask); + accum_2 += (__m256)_mm256_and_si256(_mm256_load_si256((__m256i*)&x1[i+16]), abs_mask); + accum_3 += (__m256)_mm256_and_si256(_mm256_load_si256((__m256i*)&x1[i+24]), abs_mask); } accum_0 = accum_0 + accum_1 + accum_2 + accum_3; @@ -62,8 +62,8 @@ static FLOAT sasum_kernel(BLASLONG n, FLOAT *x1) __m128i abs_mask2 = _mm_set1_epi32(0x7fffffff); for (i = tail_index_AVX2; i < tail_index_SSE; i += 8) { - accum_20 += (__m128)_mm_and_si128(_mm_loadu_si128(&x1[i + 0]), abs_mask2); - accum_21 += (__m128)_mm_and_si128(_mm_loadu_si128(&x1[i + 4]), abs_mask2); + accum_20 += (__m128)_mm_and_si128(_mm_loadu_si128((__m128i*)&x1[i + 0]), abs_mask2); + accum_21 += (__m128)_mm_and_si128(_mm_loadu_si128((__m128i*)&x1[i + 4]), abs_mask2); } accum_20 += accum_21; diff --git a/kernel/x86_64/sasum_microk_skylakex-2.c b/kernel/x86_64/sasum_microk_skylakex-2.c index c8c69d1e0..fbc91b558 100644 --- a/kernel/x86_64/sasum_microk_skylakex-2.c +++ b/kernel/x86_64/sasum_microk_skylakex-2.c @@ -53,8 +53,8 @@ static FLOAT sasum_kernel(BLASLONG n, FLOAT *x1) __m128i abs_mask2 = _mm_set1_epi32(0x7fffffff); for (i = tail_index_AVX512; i < tail_index_SSE; i += 8) { - accum_20 += (__m128)_mm_and_si128(_mm_loadu_si128(&x1[i + 0]), abs_mask2); - accum_21 += (__m128)_mm_and_si128(_mm_loadu_si128(&x1[i + 4]), abs_mask2); + accum_20 += 
(__m128)_mm_and_si128(_mm_loadu_si128((__m128i*)&x1[i + 0]), abs_mask2); + accum_21 += (__m128)_mm_and_si128(_mm_loadu_si128((__m128i*)&x1[i + 4]), abs_mask2); } accum_20 += accum_21; diff --git a/kernel/x86_64/sbdot_microk_cooperlake.c b/kernel/x86_64/sbdot_microk_cooperlake.c index 067726cb1..2aefe46ff 100644 --- a/kernel/x86_64/sbdot_microk_cooperlake.c +++ b/kernel/x86_64/sbdot_microk_cooperlake.c @@ -79,21 +79,21 @@ static float sbdot_accl_kernel(BLASLONG n, bfloat16 *x, bfloat16 *y) __m256 accum256_1 = _mm256_setzero_ps(); int tail_index_32 = n&(~31); for (int j = 0; j < tail_index_32; j += 32) { - accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256(&x[j+ 0]), (__m256bh) _mm256_loadu_si256(&y[j+ 0])); - accum256_1 = _mm256_dpbf16_ps(accum256_1, (__m256bh) _mm256_loadu_si256(&x[j+16]), (__m256bh) _mm256_loadu_si256(&y[j+16])); + accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256((__m256i *)&x[j+ 0]), (__m256bh) _mm256_loadu_si256((__m256i *)&y[j+ 0])); + accum256_1 = _mm256_dpbf16_ps(accum256_1, (__m256bh) _mm256_loadu_si256((__m256i *)&x[j+16]), (__m256bh) _mm256_loadu_si256((__m256i *)&y[j+16])); } accum256 = _mm256_add_ps(accum256, accum256_1); /* Processing the remaining <32 chunk with 16-elements processing */ if ((n&16) != 0) { - accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256(&x[tail_index_32]), (__m256bh) _mm256_loadu_si256(&y[tail_index_32])); + accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256((__m256i *)&x[tail_index_32]), (__m256bh) _mm256_loadu_si256((__m256i *)&y[tail_index_32])); } accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf128_ps(accum256, 1)); /* Processing the remaining <16 chunk with 8-elements processing */ if ((n&8) != 0) { int tail_index_16 = n&(~15); - accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128(&x[tail_index_16]), (__m128bh) _mm_loadu_si128(&y[tail_index_16])); + accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128((__m128i *)&x[tail_index_16]), (__m128bh) _mm_loadu_si128((__m128i *)&y[tail_index_16])); } /* Processing the remaining <8 chunk with masked 8-elements processing */ @@ -108,13 +108,13 @@ static float sbdot_accl_kernel(BLASLONG n, bfloat16 *x, bfloat16 *y) } else if (n > 15) { /* n range from 16 to 31 */ /* Processing <32 chunk with 16-elements processing */ __m256 accum256 = _mm256_setzero_ps(); - accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256(&x[0]), (__m256bh) _mm256_loadu_si256(&y[0])); + accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) _mm256_loadu_si256((__m256i *)&x[0]), (__m256bh) _mm256_loadu_si256((__m256i *)&y[0])); accum128 += _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf128_ps(accum256, 1)); /* Processing the remaining <16 chunk with 8-elements processing */ if ((n&8) != 0) { int tail_index_16 = n&(~15); - accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128(&x[tail_index_16]), (__m128bh) _mm_loadu_si128(&y[tail_index_16])); + accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128((__m128i *)&x[tail_index_16]), (__m128bh) _mm_loadu_si128((__m128i *)&y[tail_index_16])); } /* Processing the remaining <8 chunk with masked 8-elements processing */ @@ -128,7 +128,7 @@ static float sbdot_accl_kernel(BLASLONG n, bfloat16 *x, bfloat16 *y) } } else if (n > 7) { /* n range from 8 to 15 */ /* Processing <16 chunk with 8-elements processing */ - accum128 = _mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128(&x[0]), (__m128bh) _mm_loadu_si128(&y[0])); + accum128 = 
_mm_dpbf16_ps(accum128, (__m128bh) _mm_loadu_si128((__m128i *)&x[0]), (__m128bh) _mm_loadu_si128((__m128i *)&y[0])); /* Processing the remaining <8 chunk with masked 8-elements processing */ if ((n&7) != 0) { diff --git a/kernel/x86_64/sbgemm_block_microk_cooperlake.c b/kernel/x86_64/sbgemm_block_microk_cooperlake.c index 2376fed02..b8c41f4f7 100644 --- a/kernel/x86_64/sbgemm_block_microk_cooperlake.c +++ b/kernel/x86_64/sbgemm_block_microk_cooperlake.c @@ -1,426 +1,1871 @@ -#include "sbgemm.h" - #include + // Walk around those intrinsics that missed by compiler #define MM256_LOADU_EPI16(addr) \ _mm256_maskz_loadu_epi16(~0, (addr)) #define MM256_STOREU_EPI16(addr, reg) \ _mm256_mask_storeu_epi16((addr), ~0, (reg)) -#include -void print_block(BLASLONG m, BLASLONG n, bfloat16 * mat) -{ - printf("---- BLOCK %ld x %ld ----\n", m, n); - for (BLASLONG i=0; i> (32-m)); __m512i array512_0, array512_1, array512_2, array512_3; - BLASLONG idx_src_base0, idx_src_base1; - BLASLONG idx_target_base0, idx_target_base1; + bfloat16 * src_addr0, * src_addr1; + bfloat16 * dst_addr0, * dst_addr1; BLASLONG LDA_2x = 2*lda; BLASLONG BF16_BLOCK_T_M_2x = 2*32; - idx_src_base0 = 0; - idx_src_base1 = lda; - idx_target_base0 = 0; - idx_target_base1 = 32; + + src_addr0 = A; + src_addr1 = A + lda; + dst_addr0 = block_A; + dst_addr1 = block_A + 32; + for (BLASLONG idx_k = 0; idx_k < tag_k_2x; idx_k += 2) { - array512_0 = _mm512_loadu_si512(&A[idx_src_base0]); - array512_1 = _mm512_loadu_si512(&A[idx_src_base1]); + array512_0 = _mm512_maskz_loadu_epi16(tail_mask, src_addr0); + array512_1 = _mm512_maskz_loadu_epi16(tail_mask, src_addr1); array512_2 = _mm512_unpacklo_epi16(array512_0, array512_1); array512_3 = _mm512_unpackhi_epi16(array512_0, array512_1); - _mm512_storeu_si512(&block_A[idx_target_base0], array512_2); - _mm512_storeu_si512(&block_A[idx_target_base1], array512_3); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); - idx_src_base0 += LDA_2x; - idx_src_base1 += LDA_2x; - idx_target_base0 += BF16_BLOCK_T_M_2x; - idx_target_base1 += BF16_BLOCK_T_M_2x; + src_addr0 += LDA_2x; + src_addr1 += LDA_2x; + dst_addr0 += BF16_BLOCK_T_M_2x; + dst_addr1 += BF16_BLOCK_T_M_2x; } if (tag_k_2x != k) { __m512i ZERO512 = _mm512_setzero_si512(); - array512_0 = _mm512_loadu_si512(&A[idx_src_base0]); + array512_0 = _mm512_maskz_loadu_epi16(tail_mask, src_addr0); array512_2 = _mm512_unpacklo_epi16(array512_0, ZERO512); array512_3 = _mm512_unpackhi_epi16(array512_0, ZERO512); - _mm512_storeu_si512(&block_A[idx_target_base0], array512_2); - _mm512_storeu_si512(&block_A[idx_target_base1], array512_3); - } - -#ifdef DEBUG_PROFILE - print_block(BF16_BLOCK_THRES_K, BF16_BLOCK_THRES_M, block_A); -#endif -} - -void COL_MAJOR_INCOPY_KERNEL_Kx32m(BLASLONG k, BLASLONG m, bfloat16 * A, BLASLONG lda, bfloat16 * block_A) -{ - BLASLONG tag_k_2x = k & (~1); - unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-m)); - __mmask32 tail_mask = *((__mmask32*) &tail_mask_value); - - __m512i array512_0, array512_1, array512_2, array512_3; - - BLASLONG idx_src_base0, idx_src_base1; - BLASLONG idx_target_base0, idx_target_base1; - - BLASLONG LDA_2x = 2*lda; - BLASLONG BF16_BLOCK_T_M_2x = 2*32; - idx_src_base0 = 0; - idx_src_base1 = lda; - idx_target_base0 = 0; - idx_target_base1 = 32; - for (BLASLONG idx_k = 0; idx_k < tag_k_2x; idx_k += 2) { - array512_0 = _mm512_maskz_loadu_epi16(tail_mask, &A[idx_src_base0]); - array512_1 = _mm512_maskz_loadu_epi16(tail_mask, &A[idx_src_base1]); - array512_2 = 
_mm512_unpacklo_epi16(array512_0, array512_1); - array512_3 = _mm512_unpackhi_epi16(array512_0, array512_1); - _mm512_storeu_si512(&block_A[idx_target_base0], array512_2); - _mm512_storeu_si512(&block_A[idx_target_base1], array512_3); - - idx_src_base0 += LDA_2x; - idx_src_base1 += LDA_2x; - idx_target_base0 += BF16_BLOCK_T_M_2x; - idx_target_base1 += BF16_BLOCK_T_M_2x; - } - - if (tag_k_2x != k) { - __m512i ZERO512 = _mm512_setzero_si512(); - array512_0 = _mm512_maskz_loadu_epi16(tail_mask, &A[idx_src_base0]); - array512_2 = _mm512_unpacklo_epi16(array512_0, ZERO512); - array512_3 = _mm512_unpackhi_epi16(array512_0, ZERO512); - _mm512_storeu_si512(&block_A[idx_target_base0], array512_2); - _mm512_storeu_si512(&block_A[idx_target_base1], array512_3); - } - -#ifdef DEBUG_PROFILE - print_block(BF16_BLOCK_THRES_K, BF16_BLOCK_THRES_M, block_A); -#endif + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + } } +// INCOPY Kernel, 0> (16-m)); __m256i array256_0, array256_1, array256_2, array256_3; - BLASLONG idx_src_base0, idx_src_base1; - BLASLONG idx_target_base0; + bfloat16 * src_addr0, * src_addr1; + bfloat16 * dst_addr0; BLASLONG LDA_2x = 2*lda; - idx_src_base0 = 0; - idx_src_base1 = lda; - idx_target_base0 = 0; + + src_addr0 = A; + src_addr1 = A + lda; + dst_addr0 = block_A; + for (BLASLONG idx_k = 0; idx_k < tag_k_2x; idx_k += 2) { - array256_0 = MM256_LOADU_EPI16(&A[idx_src_base0]); - array256_1 = MM256_LOADU_EPI16(&A[idx_src_base1]); + array256_0 = _mm256_maskz_loadu_epi16(tail_mask, src_addr0); + array256_1 = _mm256_maskz_loadu_epi16(tail_mask, src_addr1); array256_2 = _mm256_unpacklo_epi16(array256_0, array256_1); array256_3 = _mm256_unpackhi_epi16(array256_0, array256_1); // Store in one row of block_B - MM256_STOREU_EPI16(&block_A[idx_target_base0], array256_2); - MM256_STOREU_EPI16(&block_A[idx_target_base0 + 16], array256_3); + MM256_STOREU_EPI16(dst_addr0, array256_2); + MM256_STOREU_EPI16(dst_addr0+16, array256_3); - idx_src_base0 += LDA_2x; - idx_src_base1 += LDA_2x; - idx_target_base0 += 32; + src_addr0 += LDA_2x; + src_addr1 += LDA_2x; + dst_addr0 += 32; } if (tag_k_2x != k) { __m256i ZERO256 = _mm256_setzero_si256(); - array256_0 = MM256_LOADU_EPI16(&A[idx_src_base0]); + array256_0 = _mm256_maskz_loadu_epi16(tail_mask, src_addr0); array256_2 = _mm256_unpacklo_epi16(array256_0, ZERO256); array256_3 = _mm256_unpackhi_epi16(array256_0, ZERO256); // Store in one row of block_B - MM256_STOREU_EPI16(&block_A[idx_target_base0], array256_2); - MM256_STOREU_EPI16(&block_A[idx_target_base0 + 16], array256_3); + MM256_STOREU_EPI16(dst_addr0, array256_2); + MM256_STOREU_EPI16(dst_addr0+16, array256_3); } - -#ifdef DEBUG_PROFILE - print_block(BF16_BLOCK_THRES_K, BF16_BLOCK_THRES_M, block_A); -#endif } -void COL_MAJOR_INCOPY_KERNEL_Kx16m(BLASLONG k, BLASLONG m, bfloat16 * A, BLASLONG lda, bfloat16 * block_A) +// K=32, M=16 +void COL_MAJOR_ITCOPY_KERNEL_32x16(bfloat16 * A, BLASLONG lda, bfloat16 * block_A) { - BLASLONG tag_k_2x = k & (~1); - unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-m)); - __mmask16 tail_mask = *((__mmask16*) &tail_mask_value); + bfloat16 * src_addr0, * src_addr1, * src_addr2, * src_addr3; + bfloat16 * dst_addr0, * dst_addr1; - __m256i array256_0, array256_1, array256_2, array256_3; + BLASLONG LDA_4x = lda*4; - BLASLONG idx_src_base0, idx_src_base1; - BLASLONG idx_target_base0; + src_addr0 = A; + src_addr1 = A + lda; + src_addr2 = A + lda*2; + src_addr3 = A + lda*3; + dst_addr0 = block_A; + dst_addr1 = block_A 
+ 32*8; - BLASLONG LDA_2x = 2*lda; - idx_src_base0 = 0; - idx_src_base1 = lda; - idx_target_base0 = 0; - for (BLASLONG idx_k = 0; idx_k < tag_k_2x; idx_k += 2) { - array256_0 = _mm256_maskz_loadu_epi16(tail_mask, &A[idx_src_base0]); - array256_1 = _mm256_maskz_loadu_epi16(tail_mask, &A[idx_src_base1]); - array256_2 = _mm256_unpacklo_epi16(array256_0, array256_1); - array256_3 = _mm256_unpackhi_epi16(array256_0, array256_1); - // Store in one row of block_B - MM256_STOREU_EPI16(&block_A[idx_target_base0], array256_2); - MM256_STOREU_EPI16(&block_A[idx_target_base0 + 16], array256_3); + __m512i array512_0, array512_1, array512_2, array512_3; + __m512i array512_way0_0, array512_way0_1, array512_way0_2, array512_way0_3; + __m512i array512_way1_0, array512_way1_1, array512_way1_2, array512_way1_3; + __m512i array512_way2_0, array512_way2_1, array512_way2_2, array512_way2_3; + __m512i array512_way3_0, array512_way3_1, array512_way3_2, array512_way3_3; - idx_src_base0 += LDA_2x; - idx_src_base1 += LDA_2x; - idx_target_base0 += 32; - } + __m512i M512_EPI64_2 = _mm512_set1_epi64(2); + __m512i permute_lo_idx = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0); + __m512i permute_hi_idx = _mm512_add_epi64(permute_lo_idx, M512_EPI64_2); - if (tag_k_2x != k) { - __m256i ZERO256 = _mm256_setzero_si256(); - array256_0 = _mm256_maskz_loadu_epi16(tail_mask, &A[idx_src_base0]); - array256_2 = _mm256_unpacklo_epi16(array256_0, ZERO256); - array256_3 = _mm256_unpackhi_epi16(array256_0, ZERO256); - // Store in one row of block_B - MM256_STOREU_EPI16(&block_A[idx_target_base0], array256_2); - MM256_STOREU_EPI16(&block_A[idx_target_base0 + 16], array256_3); - } + // Load and preprocess 1st 4 rows + array512_way0_0 = _mm512_loadu_si512(src_addr0); + array512_way0_1 = _mm512_loadu_si512(src_addr1); + array512_way0_2 = _mm512_loadu_si512(src_addr2); + array512_way0_3 = _mm512_loadu_si512(src_addr3); + array512_0 = _mm512_unpacklo_epi32(array512_way0_0, array512_way0_1); + array512_1 = _mm512_unpackhi_epi32(array512_way0_0, array512_way0_1); + array512_2 = _mm512_unpacklo_epi32(array512_way0_2, array512_way0_3); + array512_3 = _mm512_unpackhi_epi32(array512_way0_2, array512_way0_3); + array512_way0_0 = _mm512_unpacklo_epi64(array512_0, array512_2); + array512_way0_1 = _mm512_unpackhi_epi64(array512_0, array512_2); + array512_way0_2 = _mm512_unpacklo_epi64(array512_1, array512_3); + array512_way0_3 = _mm512_unpackhi_epi64(array512_1, array512_3); + src_addr0 += LDA_4x; + src_addr1 += LDA_4x; + src_addr2 += LDA_4x; + src_addr3 += LDA_4x; -#ifdef DEBUG_PROFILE - print_block(BF16_BLOCK_THRES_K, BF16_BLOCK_THRES_M, block_A); -#endif + // Load and preprocess 2nd 4 rows + array512_way1_0 = _mm512_loadu_si512(src_addr0); + array512_way1_1 = _mm512_loadu_si512(src_addr1); + array512_way1_2 = _mm512_loadu_si512(src_addr2); + array512_way1_3 = _mm512_loadu_si512(src_addr3); + array512_0 = _mm512_unpacklo_epi32(array512_way1_0, array512_way1_1); + array512_1 = _mm512_unpackhi_epi32(array512_way1_0, array512_way1_1); + array512_2 = _mm512_unpacklo_epi32(array512_way1_2, array512_way1_3); + array512_3 = _mm512_unpackhi_epi32(array512_way1_2, array512_way1_3); + array512_way1_0 = _mm512_unpacklo_epi64(array512_0, array512_2); + array512_way1_1 = _mm512_unpackhi_epi64(array512_0, array512_2); + array512_way1_2 = _mm512_unpacklo_epi64(array512_1, array512_3); + array512_way1_3 = _mm512_unpackhi_epi64(array512_1, array512_3); + src_addr0 += LDA_4x; + src_addr1 += LDA_4x; + src_addr2 += LDA_4x; + src_addr3 += LDA_4x; + + // Load and 
preprocess 3rd 4 rows
+    array512_way2_0 = _mm512_loadu_si512(src_addr0);
+    array512_way2_1 = _mm512_loadu_si512(src_addr1);
+    array512_way2_2 = _mm512_loadu_si512(src_addr2);
+    array512_way2_3 = _mm512_loadu_si512(src_addr3);
+    array512_0 = _mm512_unpacklo_epi32(array512_way2_0, array512_way2_1);
+    array512_1 = _mm512_unpackhi_epi32(array512_way2_0, array512_way2_1);
+    array512_2 = _mm512_unpacklo_epi32(array512_way2_2, array512_way2_3);
+    array512_3 = _mm512_unpackhi_epi32(array512_way2_2, array512_way2_3);
+    array512_way2_0 = _mm512_unpacklo_epi64(array512_0, array512_2);
+    array512_way2_1 = _mm512_unpackhi_epi64(array512_0, array512_2);
+    array512_way2_2 = _mm512_unpacklo_epi64(array512_1, array512_3);
+    array512_way2_3 = _mm512_unpackhi_epi64(array512_1, array512_3);
+    src_addr0 += LDA_4x;
+    src_addr1 += LDA_4x;
+    src_addr2 += LDA_4x;
+    src_addr3 += LDA_4x;
+
+    // Load and preprocess 4th 4 rows
+    array512_way3_0 = _mm512_loadu_si512(src_addr0);
+    array512_way3_1 = _mm512_loadu_si512(src_addr1);
+    array512_way3_2 = _mm512_loadu_si512(src_addr2);
+    array512_way3_3 = _mm512_loadu_si512(src_addr3);
+    array512_0 = _mm512_unpacklo_epi32(array512_way3_0, array512_way3_1);
+    array512_1 = _mm512_unpackhi_epi32(array512_way3_0, array512_way3_1);
+    array512_2 = _mm512_unpacklo_epi32(array512_way3_2, array512_way3_3);
+    array512_3 = _mm512_unpackhi_epi32(array512_way3_2, array512_way3_3);
+    array512_way3_0 = _mm512_unpacklo_epi64(array512_0, array512_2);
+    array512_way3_1 = _mm512_unpackhi_epi64(array512_0, array512_2);
+    array512_way3_2 = _mm512_unpacklo_epi64(array512_1, array512_3);
+    array512_way3_3 = _mm512_unpackhi_epi64(array512_1, array512_3);
+
+    // Compose and store the 0/1 and 16/17 cols
+    array512_0 = _mm512_permutex2var_epi64(array512_way0_0, permute_lo_idx, array512_way1_0);
+    array512_1 = _mm512_permutex2var_epi64(array512_way2_0, permute_lo_idx, array512_way3_0);
+    array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+    array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+    _mm512_storeu_si512(dst_addr0, array512_2);
+    _mm512_storeu_si512(dst_addr1, array512_3);
+    dst_addr0 += 32;
+    dst_addr1 += 32;
+
+    // Compose and store the 2/3 and 18/19 cols
+    array512_0 = _mm512_permutex2var_epi64(array512_way0_1, permute_lo_idx, array512_way1_1);
+    array512_1 = _mm512_permutex2var_epi64(array512_way2_1, permute_lo_idx, array512_way3_1);
+    array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+    array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+    _mm512_storeu_si512(dst_addr0, array512_2);
+    _mm512_storeu_si512(dst_addr1, array512_3);
+    dst_addr0 += 32;
+    dst_addr1 += 32;
+
+    // Compose and store the 4/5 and 20/21 cols
+    array512_0 = _mm512_permutex2var_epi64(array512_way0_2, permute_lo_idx, array512_way1_2);
+    array512_1 = _mm512_permutex2var_epi64(array512_way2_2, permute_lo_idx, array512_way3_2);
+    array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+    array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+    _mm512_storeu_si512(dst_addr0, array512_2);
+    _mm512_storeu_si512(dst_addr1, array512_3);
+    dst_addr0 += 32;
+    dst_addr1 += 32;
+
+    // Compose and store the 6/7 and 22/23 cols
+    array512_0 = _mm512_permutex2var_epi64(array512_way0_3, permute_lo_idx, array512_way1_3);
+    array512_1 = _mm512_permutex2var_epi64(array512_way2_3, permute_lo_idx, array512_way3_3);
+    array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+    array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+    _mm512_storeu_si512(dst_addr0, array512_2);
+    _mm512_storeu_si512(dst_addr1, array512_3);
+    dst_addr0 += 32;
+    dst_addr1 += 32;
+
+    // Compose and store the 8/9 and 24/25 cols
+    array512_0 = _mm512_permutex2var_epi64(array512_way0_0, permute_hi_idx, array512_way1_0);
+    array512_1 = _mm512_permutex2var_epi64(array512_way2_0, permute_hi_idx, array512_way3_0);
+    array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+    array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+    _mm512_storeu_si512(dst_addr0, array512_2);
+    _mm512_storeu_si512(dst_addr1, array512_3);
+    dst_addr0 += 32;
+    dst_addr1 += 32;
+
+    // Compose and store the 10/11 and 26/27 cols
+    array512_0 = _mm512_permutex2var_epi64(array512_way0_1, permute_hi_idx, array512_way1_1);
+    array512_1 = _mm512_permutex2var_epi64(array512_way2_1, permute_hi_idx, array512_way3_1);
+    array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+    array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+    _mm512_storeu_si512(dst_addr0, array512_2);
+    _mm512_storeu_si512(dst_addr1, array512_3);
+    dst_addr0 += 32;
+    dst_addr1 += 32;
+
+    // Compose and store the 12/13 and 28/29 cols
+    array512_0 = _mm512_permutex2var_epi64(array512_way0_2, permute_hi_idx, array512_way1_2);
+    array512_1 = _mm512_permutex2var_epi64(array512_way2_2, permute_hi_idx, array512_way3_2);
+    array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+    array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+    _mm512_storeu_si512(dst_addr0, array512_2);
+    _mm512_storeu_si512(dst_addr1, array512_3);
+    dst_addr0 += 32;
+    dst_addr1 += 32;
+
+    // Compose and store the 14/15 and 30/31 cols
+    array512_0 = _mm512_permutex2var_epi64(array512_way0_3, permute_hi_idx, array512_way1_3);
+    array512_1 = _mm512_permutex2var_epi64(array512_way2_3, permute_hi_idx, array512_way3_3);
+    array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1);
+    array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0);
+    _mm512_storeu_si512(dst_addr0, array512_2);
+    _mm512_storeu_si512(dst_addr1, array512_3);
 }

+// K=Any number but will be processed based on 32, M=32
+void COL_MAJOR_ITCOPY_KERNEL_Kx32(BLASLONG k, bfloat16 * A, BLASLONG lda, bfloat16 * block_A)
+{
+    bfloat16 * src_addr0, * src_addr1, * src_addr2, * src_addr3;
+    bfloat16 * dst_addr0, * dst_addr1;
+
+    BLASLONG tag_k_32x = k & (~31);
+
+    BLASLONG LDA_4x = lda*4;
+    BLASLONG LDA_8x = lda*8;
+    BLASLONG LDA_12x = lda*12;
+    BLASLONG LDA_16x = lda*16;
+
+    src_addr0 = A;
+    src_addr1 = A + lda;
+    src_addr2 = A + lda*2;
+    src_addr3 = A + lda*3;
+    dst_addr0 = block_A;
+    dst_addr1 = block_A + 32*16;
+
+    __m512i array512_0, array512_1, array512_2, array512_3;
+    __m512i array512_way0_0, array512_way0_1, array512_way0_2, array512_way0_3;
+    __m512i array512_way1_0, array512_way1_1, array512_way1_2, array512_way1_3;
+    __m512i array512_way2_0, array512_way2_1, array512_way2_2, array512_way2_3;
+    __m512i array512_way3_0, array512_way3_1, array512_way3_2, array512_way3_3;
+
+    __m512i M512_EPI64_2 = _mm512_set1_epi64(2);
+    __m512i permute_lo_idx = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0);
+    __m512i permute_hi_idx = 
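+    // permute_lo_idx gathers the 64-bit quads {0,1, 8,9, 4,5, 12,13} from its two sources,
+    // i.e. 128-bit lanes 0 and 2 of both 4-row groups; permute_hi_idx (the same pattern
+    // plus 2, built just below) gathers lanes 1 and 3. After the epi32/epi64 unpacks each
+    // lane holds one k-pair for four rows, so one permute collects a k-pair across 8 rows.
+    // A scalar sketch of the packing this kernel performs (illustrative only; i is the row
+    // index within the 32-row panel, kp the k-pair index, and odd-k tails are zero-padded
+    // by the masked loads):
+    //   for (BLASLONG kp = 0; kp < (k+1)/2; kp++)
+    //       for (BLASLONG i = 0; i < 32; i++)
+    //           for (int t = 0; t < 2; t++)
+    //               block_A[kp*64 + i*2 + t] = A[i*lda + kp*2 + t];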
_mm512_add_epi64(permute_lo_idx, M512_EPI64_2); + + for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) { + for (int i = 0; i < 2; i++) { + // Load and preprocess 1st 4 rows + array512_way0_0 = _mm512_loadu_si512(src_addr0+idx_k); + array512_way0_1 = _mm512_loadu_si512(src_addr1+idx_k); + array512_way0_2 = _mm512_loadu_si512(src_addr2+idx_k); + array512_way0_3 = _mm512_loadu_si512(src_addr3+idx_k); + array512_0 = _mm512_unpacklo_epi32(array512_way0_0, array512_way0_1); + array512_1 = _mm512_unpackhi_epi32(array512_way0_0, array512_way0_1); + array512_2 = _mm512_unpacklo_epi32(array512_way0_2, array512_way0_3); + array512_3 = _mm512_unpackhi_epi32(array512_way0_2, array512_way0_3); + array512_way0_0 = _mm512_unpacklo_epi64(array512_0, array512_2); + array512_way0_1 = _mm512_unpackhi_epi64(array512_0, array512_2); + array512_way0_2 = _mm512_unpacklo_epi64(array512_1, array512_3); + array512_way0_3 = _mm512_unpackhi_epi64(array512_1, array512_3); + + // Load and preprocess 2nd 4 rows + array512_way1_0 = _mm512_loadu_si512(src_addr0+LDA_4x+idx_k); + array512_way1_1 = _mm512_loadu_si512(src_addr1+LDA_4x+idx_k); + array512_way1_2 = _mm512_loadu_si512(src_addr2+LDA_4x+idx_k); + array512_way1_3 = _mm512_loadu_si512(src_addr3+LDA_4x+idx_k); + array512_0 = _mm512_unpacklo_epi32(array512_way1_0, array512_way1_1); + array512_1 = _mm512_unpackhi_epi32(array512_way1_0, array512_way1_1); + array512_2 = _mm512_unpacklo_epi32(array512_way1_2, array512_way1_3); + array512_3 = _mm512_unpackhi_epi32(array512_way1_2, array512_way1_3); + array512_way1_0 = _mm512_unpacklo_epi64(array512_0, array512_2); + array512_way1_1 = _mm512_unpackhi_epi64(array512_0, array512_2); + array512_way1_2 = _mm512_unpacklo_epi64(array512_1, array512_3); + array512_way1_3 = _mm512_unpackhi_epi64(array512_1, array512_3); + + // Load and preprocess 3rd 4 rows + array512_way2_0 = _mm512_loadu_si512(src_addr0+LDA_8x+idx_k); + array512_way2_1 = _mm512_loadu_si512(src_addr1+LDA_8x+idx_k); + array512_way2_2 = _mm512_loadu_si512(src_addr2+LDA_8x+idx_k); + array512_way2_3 = _mm512_loadu_si512(src_addr3+LDA_8x+idx_k); + array512_0 = _mm512_unpacklo_epi32(array512_way2_0, array512_way2_1); + array512_1 = _mm512_unpackhi_epi32(array512_way2_0, array512_way2_1); + array512_2 = _mm512_unpacklo_epi32(array512_way2_2, array512_way2_3); + array512_3 = _mm512_unpackhi_epi32(array512_way2_2, array512_way2_3); + array512_way2_0 = _mm512_unpacklo_epi64(array512_0, array512_2); + array512_way2_1 = _mm512_unpackhi_epi64(array512_0, array512_2); + array512_way2_2 = _mm512_unpacklo_epi64(array512_1, array512_3); + array512_way2_3 = _mm512_unpackhi_epi64(array512_1, array512_3); + + // Load and preprocess 4th 4 rows + array512_way3_0 = _mm512_loadu_si512(src_addr0+LDA_12x+idx_k); + array512_way3_1 = _mm512_loadu_si512(src_addr1+LDA_12x+idx_k); + array512_way3_2 = _mm512_loadu_si512(src_addr2+LDA_12x+idx_k); + array512_way3_3 = _mm512_loadu_si512(src_addr3+LDA_12x+idx_k); + array512_0 = _mm512_unpacklo_epi32(array512_way3_0, array512_way3_1); + array512_1 = _mm512_unpackhi_epi32(array512_way3_0, array512_way3_1); + array512_2 = _mm512_unpacklo_epi32(array512_way3_2, array512_way3_3); + array512_3 = _mm512_unpackhi_epi32(array512_way3_2, array512_way3_3); + array512_way3_0 = _mm512_unpacklo_epi64(array512_0, array512_2); + array512_way3_1 = _mm512_unpackhi_epi64(array512_0, array512_2); + array512_way3_2 = _mm512_unpacklo_epi64(array512_1, array512_3); + array512_way3_3 = _mm512_unpackhi_epi64(array512_1, array512_3); + + // Compose and store the 0/1 and 
16/17 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_0, permute_lo_idx, array512_way1_0); + array512_1 = _mm512_permutex2var_epi64(array512_way2_0, permute_lo_idx, array512_way3_0); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 64; + dst_addr1 += 64; + + // Compose and store the 2/3 and 18/19 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_1, permute_lo_idx, array512_way1_1); + array512_1 = _mm512_permutex2var_epi64(array512_way2_1, permute_lo_idx, array512_way3_1); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 64; + dst_addr1 += 64; + + // Compose and store the 4/5 and 20/21 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_2, permute_lo_idx, array512_way1_2); + array512_1 = _mm512_permutex2var_epi64(array512_way2_2, permute_lo_idx, array512_way3_2); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 64; + dst_addr1 += 64; + + // Compose and store the 6/7 and 22/23 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_3, permute_lo_idx, array512_way1_3); + array512_1 = _mm512_permutex2var_epi64(array512_way2_3, permute_lo_idx, array512_way3_3); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 64; + dst_addr1 += 64; + + // Compose and store the 8/9 and 24/25 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_0, permute_hi_idx, array512_way1_0); + array512_1 = _mm512_permutex2var_epi64(array512_way2_0, permute_hi_idx, array512_way3_0); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 64; + dst_addr1 += 64; + + // Compose and store the 10/11 and 26/27 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_1, permute_hi_idx, array512_way1_1); + array512_1 = _mm512_permutex2var_epi64(array512_way2_1, permute_hi_idx, array512_way3_1); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 64; + dst_addr1 += 64; + + // Compose and store the 12/13 and 28/29 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_2, permute_hi_idx, array512_way1_2); + array512_1 = _mm512_permutex2var_epi64(array512_way2_2, permute_hi_idx, array512_way3_2); + array512_2 = _mm512_inserti64x4(array512_0, 
_mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 64; + dst_addr1 += 64; + + // Compose and store the 14/15 and 30/31 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_3, permute_hi_idx, array512_way1_3); + array512_1 = _mm512_permutex2var_epi64(array512_way2_3, permute_hi_idx, array512_way3_3); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + + src_addr0 += LDA_16x; + src_addr1 += LDA_16x; + src_addr2 += LDA_16x; + src_addr3 += LDA_16x; + dst_addr0 -= (64*7 - 32); + dst_addr1 -= (64*7 - 32); + } + src_addr0 -= (LDA_16x*2); + src_addr1 -= (LDA_16x*2); + src_addr2 -= (LDA_16x*2); + src_addr3 -= (LDA_16x*2); + dst_addr0 += (32*30); + dst_addr1 += (32*30); + } + + if (tag_k_32x != k) { + int k_rem = k - tag_k_32x; + unsigned int tail_mask = (((unsigned int)0xffffffff) >> (32-k_rem)); + __m512i array512[16]; + + bfloat16 * dst_addr_tmp = dst_addr0; + + for (int i = 0; i < 2; i++) { + // Load and preprocess 1st 4 rows + array512[0] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x); + array512[1] = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x); + array512[2] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x); + array512[3] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x); + array512_0 = _mm512_unpacklo_epi32(array512[0], array512[1]); + array512_1 = _mm512_unpackhi_epi32(array512[0], array512[1]); + array512_2 = _mm512_unpacklo_epi32(array512[2], array512[3]); + array512_3 = _mm512_unpackhi_epi32(array512[2], array512[3]); + array512[0] = _mm512_unpacklo_epi64(array512_0, array512_2); + array512[1] = _mm512_unpackhi_epi64(array512_0, array512_2); + array512[2] = _mm512_unpacklo_epi64(array512_1, array512_3); + array512[3] = _mm512_unpackhi_epi64(array512_1, array512_3); + src_addr0 += LDA_4x; + src_addr1 += LDA_4x; + src_addr2 += LDA_4x; + src_addr3 += LDA_4x; + + // Load and preprocess 2nd 4 rows + array512[4] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x); + array512[5] = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x); + array512[6] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x); + array512[7] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x); + array512_0 = _mm512_unpacklo_epi32(array512[4], array512[5]); + array512_1 = _mm512_unpackhi_epi32(array512[4], array512[5]); + array512_2 = _mm512_unpacklo_epi32(array512[6], array512[7]); + array512_3 = _mm512_unpackhi_epi32(array512[6], array512[7]); + array512[4] = _mm512_unpacklo_epi64(array512_0, array512_2); + array512[5] = _mm512_unpackhi_epi64(array512_0, array512_2); + array512[6] = _mm512_unpacklo_epi64(array512_1, array512_3); + array512[7] = _mm512_unpackhi_epi64(array512_1, array512_3); + src_addr0 += LDA_4x; + src_addr1 += LDA_4x; + src_addr2 += LDA_4x; + src_addr3 += LDA_4x; + + // Load and preprocess 3rd 4 rows + array512[8] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x); + array512[9] = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x); + array512[10] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x); + array512[11] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x); + array512_0 = 
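+            // The epi32 unpacklo/unpackhi pair followed by the epi64 unpacklo/unpackhi
+            // pair is a 4x4 transpose of 32-bit elements inside each 128-bit lane: after
+            // it, every lane of array512[8..11] holds one k-pair from each of these four
+            // rows, ready for the cross-lane permutes below.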
_mm512_unpacklo_epi32(array512[8], array512[9]); + array512_1 = _mm512_unpackhi_epi32(array512[8], array512[9]); + array512_2 = _mm512_unpacklo_epi32(array512[10], array512[11]); + array512_3 = _mm512_unpackhi_epi32(array512[10], array512[11]); + array512[8] = _mm512_unpacklo_epi64(array512_0, array512_2); + array512[9] = _mm512_unpackhi_epi64(array512_0, array512_2); + array512[10] = _mm512_unpacklo_epi64(array512_1, array512_3); + array512[11] = _mm512_unpackhi_epi64(array512_1, array512_3); + src_addr0 += LDA_4x; + src_addr1 += LDA_4x; + src_addr2 += LDA_4x; + src_addr3 += LDA_4x; + + // Load and preprocess 4th 4 rows + array512[12] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x); + array512[13] = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x); + array512[14] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x); + array512[15] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x); + array512_0 = _mm512_unpacklo_epi32(array512[12], array512[13]); + array512_1 = _mm512_unpackhi_epi32(array512[12], array512[13]); + array512_2 = _mm512_unpacklo_epi32(array512[14], array512[15]); + array512_3 = _mm512_unpackhi_epi32(array512[14], array512[15]); + array512[12] = _mm512_unpacklo_epi64(array512_0, array512_2); + array512[13] = _mm512_unpackhi_epi64(array512_0, array512_2); + array512[14] = _mm512_unpacklo_epi64(array512_1, array512_3); + array512[15] = _mm512_unpackhi_epi64(array512_1, array512_3); + src_addr0 += LDA_4x; + src_addr1 += LDA_4x; + src_addr2 += LDA_4x; + src_addr3 += LDA_4x; + + // array512_01_1617_0, array512_01_1617_1, array512_89_2425_0, array512_89_2425_1; + // Half-compose of 0/1, 16/17, 8/9, 24/25 cols + array512_0 = _mm512_permutex2var_epi64(array512[0], permute_lo_idx, array512[4]); + array512_1 = _mm512_permutex2var_epi64(array512[8], permute_lo_idx, array512[12]); + array512_2 = _mm512_permutex2var_epi64(array512[0], permute_hi_idx, array512[4]); + array512_3 = _mm512_permutex2var_epi64(array512[8], permute_hi_idx, array512[12]); + array512[0] = array512_0; // 1st 8 pairs of col 0/1, and 1st 8 pairs of col 16/17 + array512[4] = array512_1; // 2nd 8 pairs of col 0/1, and 2nd 8 pairs of col 16/17 + array512[8] = array512_2; // 1st 8 pairs of col 8/9, and 1st 8 pairs of col 24/25 + array512[12] = array512_3; // 2nd 8 pairs of col 8/9, and 2nd 8 pairs of col 24/25 + + // Half-compose of 2/3, 18/19, 10/11, 26/27 cols + array512_0 = _mm512_permutex2var_epi64(array512[1], permute_lo_idx, array512[5]); + array512_1 = _mm512_permutex2var_epi64(array512[9], permute_lo_idx, array512[13]); + array512_2 = _mm512_permutex2var_epi64(array512[1], permute_hi_idx, array512[5]); + array512_3 = _mm512_permutex2var_epi64(array512[9], permute_hi_idx, array512[13]); + array512[1] = array512_0; // 1st 8 pairs of col 2/3, and 1st 8 pairs of col 18/19 + array512[5] = array512_1; // 2nd 8 pairs of col 2/3, and 2nd 8 pairs of col 18/19 + array512[9] = array512_2; // 1st 8 pairs of col 10/11, and 1st 8 pairs of col 26/27 + array512[13] = array512_3; // 2nd 8 pairs of col 10/11, and 2nd 8 pairs of col 26/27 + + // Half-compose of 4/5, 20/21, 12/13, 28/29 cols + array512_0 = _mm512_permutex2var_epi64(array512[2], permute_lo_idx, array512[6]); + array512_1 = _mm512_permutex2var_epi64(array512[10], permute_lo_idx, array512[14]); + array512_2 = _mm512_permutex2var_epi64(array512[2], permute_hi_idx, array512[6]); + array512_3 = _mm512_permutex2var_epi64(array512[10], permute_hi_idx, array512[14]); + array512[2] = array512_0; // 1st 8 pairs of col 4/5, and 1st 8 pairs of col 
20/21
+            array512[6] = array512_1; // 2nd 8 pairs of col 4/5, and 2nd 8 pairs of col 20/21
+            array512[10] = array512_2; // 1st 8 pairs of col 12/13, and 1st 8 pairs of col 28/29
+            array512[14] = array512_3; // 2nd 8 pairs of col 12/13, and 2nd 8 pairs of col 28/29
+
+            // Half-compose of 6/7, 22/23, 14/15, 30/31 cols
+            array512_0 = _mm512_permutex2var_epi64(array512[3], permute_lo_idx, array512[7]);
+            array512_1 = _mm512_permutex2var_epi64(array512[11], permute_lo_idx, array512[15]);
+            array512_2 = _mm512_permutex2var_epi64(array512[3], permute_hi_idx, array512[7]);
+            array512_3 = _mm512_permutex2var_epi64(array512[11], permute_hi_idx, array512[15]);
+            array512[3] = array512_0; // 1st 8 pairs of col 6/7, and 1st 8 pairs of col 22/23
+            array512[7] = array512_1; // 2nd 8 pairs of col 6/7, and 2nd 8 pairs of col 22/23
+            array512[11] = array512_2; // 1st 8 pairs of col 14/15, and 1st 8 pairs of col 30/31
+            array512[15] = array512_3; // 2nd 8 pairs of col 14/15, and 2nd 8 pairs of col 30/31
+
+            // Compose and store the 0/1 cols
+            array512_0 = _mm512_inserti64x4(array512[0], _mm512_castsi512_si256(array512[4]), 0x1);
+            _mm512_storeu_si512(dst_addr0, array512_0);
+            dst_addr0 += 64;
+
+            // Compose and store the 2/3 cols
+            array512_0 = _mm512_inserti64x4(array512[1], _mm512_castsi512_si256(array512[5]), 0x1);
+            _mm512_storeu_si512(dst_addr0, array512_0);
+            dst_addr0 += 64;
+
+            // Compose and store the 4/5 cols
+            array512_0 = _mm512_inserti64x4(array512[2], _mm512_castsi512_si256(array512[6]), 0x1);
+            _mm512_storeu_si512(dst_addr0, array512_0);
+            dst_addr0 += 64;
+
+            // Compose and store the 6/7 cols
+            array512_0 = _mm512_inserti64x4(array512[3], _mm512_castsi512_si256(array512[7]), 0x1);
+            _mm512_storeu_si512(dst_addr0, array512_0);
+            dst_addr0 += 64;
+
+            // Compose and store the 8/9 cols
+            array512_0 = _mm512_inserti64x4(array512[8], _mm512_castsi512_si256(array512[12]), 0x1);
+            _mm512_storeu_si512(dst_addr0, array512_0);
+            dst_addr0 += 64;
+
+            // Compose and store the 10/11 cols
+            array512_0 = _mm512_inserti64x4(array512[9], _mm512_castsi512_si256(array512[13]), 0x1);
+            _mm512_storeu_si512(dst_addr0, array512_0);
+            dst_addr0 += 64;
+
+            // Compose and store the 12/13 cols
+            array512_0 = _mm512_inserti64x4(array512[10], _mm512_castsi512_si256(array512[14]), 0x1);
+            _mm512_storeu_si512(dst_addr0, array512_0);
+            dst_addr0 += 64;
+
+            // Compose and store the 14/15 cols
+            array512_0 = _mm512_inserti64x4(array512[11], _mm512_castsi512_si256(array512[15]), 0x1);
+            _mm512_storeu_si512(dst_addr0, array512_0);
+            dst_addr0 += 64;
+
+            // Compose and store 16 ~ k_rem cols
+            int idx_length = (k_rem + 1 - 16) >> 1;
+            if (idx_length > 4) {
+                for (int idx_k = 0; idx_k < 4; idx_k++) {
+                    array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0);
+                    _mm512_storeu_si512(dst_addr0, array512_0);
+                    dst_addr0 += 64;
+                }
+
+                for (int idx_k = 4; idx_k < idx_length; idx_k++) {
+                    array512_0 = _mm512_inserti64x4(array512[idx_k+8], _mm512_extracti64x4_epi64(array512[idx_k+4], 0x1), 0x0);
+                    _mm512_storeu_si512(dst_addr0, array512_0);
+                    dst_addr0 += 64;
+                }
+            } else {
+                for (int idx_k = 0; idx_k < idx_length; idx_k++) {
+                    array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0);
+                    _mm512_storeu_si512(dst_addr0, array512_0);
+                    dst_addr0 += 64;
+                }
+            }
+
+            dst_addr0 = dst_addr_tmp + 32;
+        }
+    }
+}
+
+// K=Any number but will be processed based on 32, 16<M<32
+void COL_MAJOR_ITCOPY_KERNEL_Kx32m(BLASLONG m, BLASLONG k, bfloat16 * A, BLASLONG lda, bfloat16 * block_A)
+{
+    bfloat16 * src_addr0, * src_addr1, * src_addr2, * src_addr3;
+    bfloat16 * dst_addr0;
+
+    BLASLONG tag_k_32x = k & (~31);
+    BLASLONG LDA_4x = lda*4;
+    int m_rem = m - 16;
+
+    src_addr0 = A;
+    src_addr1 = A + lda;
+    src_addr2 = A + lda*2;
+    src_addr3 = A + lda*3;
+    dst_addr0 = block_A;
+
+    __m512i array512_0, array512_1, array512_2, array512_3;
+    __m512i array512[16];
+
+    __m512i M512_EPI64_2 = _mm512_set1_epi64(2);
+    __m512i permute_lo_idx = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0);
+    __m512i permute_hi_idx = _mm512_add_epi64(permute_lo_idx, M512_EPI64_2);
+
+    if (tag_k_32x != k) {
+        int k_rem = k - tag_k_32x;
+        int idx_length = (k_rem + 1 - 16) >> 1;
+        unsigned int tail_mask = (((unsigned int)0xffffffff) >> (32-k_rem));
+        bfloat16 * 
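+        // dst_addr_tmp (declared just below) keeps the base of this tail block: the first
+        // 16 rows are packed from dst_addr0, then dst_addr0 is reset to dst_addr_tmp + 32
+        // so the remaining m-16 rows fill the second half of each 64-element k-pair slot.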
dst_addr_tmp = dst_addr0; + + for (int j = 0; j < 4; j++) { + int array_idx = j*4; + // Load and preprocess 4 rows + array512[array_idx+0] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x); + array512[array_idx+1] = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x); + array512[array_idx+2] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x); + array512[array_idx+3] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x); + array512_0 = _mm512_unpacklo_epi32(array512[array_idx+0], array512[array_idx+1]); + array512_1 = _mm512_unpackhi_epi32(array512[array_idx+0], array512[array_idx+1]); + array512_2 = _mm512_unpacklo_epi32(array512[array_idx+2], array512[array_idx+3]); + array512_3 = _mm512_unpackhi_epi32(array512[array_idx+2], array512[array_idx+3]); + array512[array_idx+0] = _mm512_unpacklo_epi64(array512_0, array512_2); + array512[array_idx+1] = _mm512_unpackhi_epi64(array512_0, array512_2); + array512[array_idx+2] = _mm512_unpacklo_epi64(array512_1, array512_3); + array512[array_idx+3] = _mm512_unpackhi_epi64(array512_1, array512_3); + src_addr0 += LDA_4x; + src_addr1 += LDA_4x; + src_addr2 += LDA_4x; + src_addr3 += LDA_4x; + } + + for (int j = 0; j < 4; j++) { + array512_0 = _mm512_permutex2var_epi64(array512[j+0], permute_lo_idx, array512[j+4]); + array512_1 = _mm512_permutex2var_epi64(array512[j+8], permute_lo_idx, array512[j+12]); + array512_2 = _mm512_permutex2var_epi64(array512[j+0], permute_hi_idx, array512[j+4]); + array512_3 = _mm512_permutex2var_epi64(array512[j+8], permute_hi_idx, array512[j+12]); + array512[j+0] = array512_0; // 1st 8 pairs of col 0/1|2/3|4/5|6/7, and 1st 8 pairs of col 16/17|18/19|20/21|22/23 + array512[j+4] = array512_1; // 2nd 8 pairs of col 0/1|2/3|4/5|6/7, and 2nd 8 pairs of col 16/17|18/19|20/21|22/23 + array512[j+8] = array512_2; // 1st 8 pairs of col 8/9|10/11|12/13|14/15, and 1st 8 pairs of col 24/25|26/27|28/29|30/31 + array512[j+12] = array512_3; // 2nd 8 pairs of col 8/9|10/11|12/13|14/15, and 2nd 8 pairs of col 24/25|26/27|28/29|30/31 + } + + for (int j = 0; j < 4; j++) { + // Compose and store the 0/1 cols + array512_0 = _mm512_inserti64x4(array512[j], _mm512_castsi512_si256(array512[j+4]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 64; + } + + for (int j = 8; j < 12; j++) { + array512_0 = _mm512_inserti64x4(array512[j], _mm512_castsi512_si256(array512[j+4]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 64; + } + + // Compose and store 16 ~ k_rem cols + if (idx_length > 4) { + for (int idx_k = 0; idx_k < 4; idx_k++) { + array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 64; + } + + for (int idx_k = 4; idx_k < idx_length; idx_k++) { + array512_0 = _mm512_inserti64x4(array512[idx_k+8], _mm512_extracti64x4_epi64(array512[idx_k+4], 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 64; + } + } else { + for (int idx_k = 0; idx_k < idx_length; idx_k++) { + array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 64; + } + } + + dst_addr0 = dst_addr_tmp + 32; + + for (int j = 0; j < m_rem; j++) { + array512[j] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+j*lda+tag_k_32x); + } + for (int j = m_rem; j < 16; j++) { + array512[j] = _mm512_setzero_si512(); + } + + for (int j = 0; j < 4; j++) { + int array_idx = j*4; + array512_0 = 
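+        // Rows m_rem..15 were zero-filled above so the fixed 16-row transpose path can run
+        // unchanged; the zero lanes only pad the packed block and contribute nothing to the
+        // later GEMM accumulation.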
_mm512_unpacklo_epi32(array512[array_idx+0], array512[array_idx+1]); + array512_1 = _mm512_unpackhi_epi32(array512[array_idx+0], array512[array_idx+1]); + array512_2 = _mm512_unpacklo_epi32(array512[array_idx+2], array512[array_idx+3]); + array512_3 = _mm512_unpackhi_epi32(array512[array_idx+2], array512[array_idx+3]); + array512[array_idx+0] = _mm512_unpacklo_epi64(array512_0, array512_2); + array512[array_idx+1] = _mm512_unpackhi_epi64(array512_0, array512_2); + array512[array_idx+2] = _mm512_unpacklo_epi64(array512_1, array512_3); + array512[array_idx+3] = _mm512_unpackhi_epi64(array512_1, array512_3); + } + + for (int j = 0; j < 4; j++) { + array512_0 = _mm512_permutex2var_epi64(array512[j+0], permute_lo_idx, array512[j+4]); + array512_1 = _mm512_permutex2var_epi64(array512[j+8], permute_lo_idx, array512[j+12]); + array512_2 = _mm512_permutex2var_epi64(array512[j+0], permute_hi_idx, array512[j+4]); + array512_3 = _mm512_permutex2var_epi64(array512[j+8], permute_hi_idx, array512[j+12]); + array512[j+0] = array512_0; // 1st 8 pairs of col 0/1|2/3|4/5|6/7, and 1st 8 pairs of col 16/17|18/19|20/21|22/23 + array512[j+4] = array512_1; // 2nd 8 pairs of col 0/1|2/3|4/5|6/7, and 2nd 8 pairs of col 16/17|18/19|20/21|22/23 + array512[j+8] = array512_2; // 1st 8 pairs of col 8/9|10/11|12/13|14/15, and 1st 8 pairs of col 24/25|26/27|28/29|30/31 + array512[j+12] = array512_3; // 2nd 8 pairs of col 8/9|10/11|12/13|14/15, and 2nd 8 pairs of col 24/25|26/27|28/29|30/31 + } + + for (int j = 0; j < 4; j++) { + // Compose and store the 0/1 cols + array512_0 = _mm512_inserti64x4(array512[j], _mm512_castsi512_si256(array512[j+4]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 64; + } + + for (int j = 8; j < 12; j++) { + array512_0 = _mm512_inserti64x4(array512[j], _mm512_castsi512_si256(array512[j+4]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 64; + } + + // Compose and store 16 ~ k_rem cols + if (idx_length > 4) { + for (int idx_k = 0; idx_k < 4; idx_k++) { + array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 64; + } + + for (int idx_k = 4; idx_k < idx_length; idx_k++) { + array512_0 = _mm512_inserti64x4(array512[idx_k+8], _mm512_extracti64x4_epi64(array512[idx_k+4], 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 64; + } + } else { + for (int idx_k = 0; idx_k < idx_length; idx_k++) { + array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 64; + } + } + } +} + +// K=Any number but will be processed based on 32, M=16 +void COL_MAJOR_ITCOPY_KERNEL_Kx16(BLASLONG k, bfloat16 * A, BLASLONG lda, bfloat16 * block_A) +{ + bfloat16 * src_addr0, * src_addr1, * src_addr2, * src_addr3; + bfloat16 * dst_addr0, * dst_addr1; + + BLASLONG tag_k_32x = k & (~31); + + BLASLONG LDA_4x = lda*4; + BLASLONG LDA_8x = lda*8; + BLASLONG LDA_12x = lda*12; + + src_addr0 = A; + src_addr1 = A + lda; + src_addr2 = A + lda*2; + src_addr3 = A + lda*3; + dst_addr0 = block_A; + dst_addr1 = block_A + 32*8; + + __m512i array512_0, array512_1, array512_2, array512_3; + __m512i array512_way0_0, array512_way0_1, array512_way0_2, array512_way0_3; + __m512i array512_way1_0, array512_way1_1, array512_way1_2, array512_way1_3; + __m512i array512_way2_0, array512_way2_1, array512_way2_2, array512_way2_3; + __m512i array512_way3_0, array512_way3_1, 
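+    // Each array512_wayN_* group buffers one 4-row band of the 16 source rows, while
+    // dst_addr1 = block_A + 32*8 receives column pairs 8..15 (source columns 16..31) and
+    // dst_addr0 receives pairs 0..7.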
array512_way3_2, array512_way3_3; + + __m512i M512_EPI64_2 = _mm512_set1_epi64(2); + __m512i permute_lo_idx = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0); + __m512i permute_hi_idx = _mm512_add_epi64(permute_lo_idx, M512_EPI64_2); + + for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) { + // Load and preprocess 1st 4 rows + array512_way0_0 = _mm512_loadu_si512(src_addr0+idx_k); + array512_way0_1 = _mm512_loadu_si512(src_addr1+idx_k); + array512_way0_2 = _mm512_loadu_si512(src_addr2+idx_k); + array512_way0_3 = _mm512_loadu_si512(src_addr3+idx_k); + array512_0 = _mm512_unpacklo_epi32(array512_way0_0, array512_way0_1); + array512_1 = _mm512_unpackhi_epi32(array512_way0_0, array512_way0_1); + array512_2 = _mm512_unpacklo_epi32(array512_way0_2, array512_way0_3); + array512_3 = _mm512_unpackhi_epi32(array512_way0_2, array512_way0_3); + array512_way0_0 = _mm512_unpacklo_epi64(array512_0, array512_2); + array512_way0_1 = _mm512_unpackhi_epi64(array512_0, array512_2); + array512_way0_2 = _mm512_unpacklo_epi64(array512_1, array512_3); + array512_way0_3 = _mm512_unpackhi_epi64(array512_1, array512_3); + + // Load and preprocess 2nd 4 rows + array512_way1_0 = _mm512_loadu_si512(src_addr0+LDA_4x+idx_k); + array512_way1_1 = _mm512_loadu_si512(src_addr1+LDA_4x+idx_k); + array512_way1_2 = _mm512_loadu_si512(src_addr2+LDA_4x+idx_k); + array512_way1_3 = _mm512_loadu_si512(src_addr3+LDA_4x+idx_k); + array512_0 = _mm512_unpacklo_epi32(array512_way1_0, array512_way1_1); + array512_1 = _mm512_unpackhi_epi32(array512_way1_0, array512_way1_1); + array512_2 = _mm512_unpacklo_epi32(array512_way1_2, array512_way1_3); + array512_3 = _mm512_unpackhi_epi32(array512_way1_2, array512_way1_3); + array512_way1_0 = _mm512_unpacklo_epi64(array512_0, array512_2); + array512_way1_1 = _mm512_unpackhi_epi64(array512_0, array512_2); + array512_way1_2 = _mm512_unpacklo_epi64(array512_1, array512_3); + array512_way1_3 = _mm512_unpackhi_epi64(array512_1, array512_3); + + // Load and preprocess 3rd 4 rows + array512_way2_0 = _mm512_loadu_si512(src_addr0+LDA_8x+idx_k); + array512_way2_1 = _mm512_loadu_si512(src_addr1+LDA_8x+idx_k); + array512_way2_2 = _mm512_loadu_si512(src_addr2+LDA_8x+idx_k); + array512_way2_3 = _mm512_loadu_si512(src_addr3+LDA_8x+idx_k); + array512_0 = _mm512_unpacklo_epi32(array512_way2_0, array512_way2_1); + array512_1 = _mm512_unpackhi_epi32(array512_way2_0, array512_way2_1); + array512_2 = _mm512_unpacklo_epi32(array512_way2_2, array512_way2_3); + array512_3 = _mm512_unpackhi_epi32(array512_way2_2, array512_way2_3); + array512_way2_0 = _mm512_unpacklo_epi64(array512_0, array512_2); + array512_way2_1 = _mm512_unpackhi_epi64(array512_0, array512_2); + array512_way2_2 = _mm512_unpacklo_epi64(array512_1, array512_3); + array512_way2_3 = _mm512_unpackhi_epi64(array512_1, array512_3); + + // Load and preprocess 4th 4 rows + array512_way3_0 = _mm512_loadu_si512(src_addr0+LDA_12x+idx_k); + array512_way3_1 = _mm512_loadu_si512(src_addr1+LDA_12x+idx_k); + array512_way3_2 = _mm512_loadu_si512(src_addr2+LDA_12x+idx_k); + array512_way3_3 = _mm512_loadu_si512(src_addr3+LDA_12x+idx_k); + array512_0 = _mm512_unpacklo_epi32(array512_way3_0, array512_way3_1); + array512_1 = _mm512_unpackhi_epi32(array512_way3_0, array512_way3_1); + array512_2 = _mm512_unpacklo_epi32(array512_way3_2, array512_way3_3); + array512_3 = _mm512_unpackhi_epi32(array512_way3_2, array512_way3_3); + array512_way3_0 = _mm512_unpacklo_epi64(array512_0, array512_2); + array512_way3_1 = _mm512_unpackhi_epi64(array512_0, array512_2); + array512_way3_2 = 
_mm512_unpacklo_epi64(array512_1, array512_3); + array512_way3_3 = _mm512_unpackhi_epi64(array512_1, array512_3); + + // Compose and store the 0/1 and 16/17 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_0, permute_lo_idx, array512_way1_0); + array512_1 = _mm512_permutex2var_epi64(array512_way2_0, permute_lo_idx, array512_way3_0); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 32; + dst_addr1 += 32; + + // Compose and store the 2/3 and 18/19 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_1, permute_lo_idx, array512_way1_1); + array512_1 = _mm512_permutex2var_epi64(array512_way2_1, permute_lo_idx, array512_way3_1); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 32; + dst_addr1 += 32; + + // Compose and store the 4/5 and 20/21 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_2, permute_lo_idx, array512_way1_2); + array512_1 = _mm512_permutex2var_epi64(array512_way2_2, permute_lo_idx, array512_way3_2); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 32; + dst_addr1 += 32; + + // Compose and store the 6/7 and 22/23 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_3, permute_lo_idx, array512_way1_3); + array512_1 = _mm512_permutex2var_epi64(array512_way2_3, permute_lo_idx, array512_way3_3); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 32; + dst_addr1 += 32; + + // Compose and store the 8/9 and 24/25 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_0, permute_hi_idx, array512_way1_0); + array512_1 = _mm512_permutex2var_epi64(array512_way2_0, permute_hi_idx, array512_way3_0); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 32; + dst_addr1 += 32; + + // Compose and store the 10/11 and 26/27 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_1, permute_hi_idx, array512_way1_1); + array512_1 = _mm512_permutex2var_epi64(array512_way2_1, permute_hi_idx, array512_way3_1); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 32; + dst_addr1 += 32; + + // Compose and store the 12/13 and 28/29 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_2, permute_hi_idx, array512_way1_2); + array512_1 = 
_mm512_permutex2var_epi64(array512_way2_2, permute_hi_idx, array512_way3_2); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 32; + dst_addr1 += 32; + + // Compose and store the 14/15 and 30/31 cols + array512_0 = _mm512_permutex2var_epi64(array512_way0_3, permute_hi_idx, array512_way1_3); + array512_1 = _mm512_permutex2var_epi64(array512_way2_3, permute_hi_idx, array512_way3_3); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 32*9; + dst_addr1 += 32*9; + } + + if (tag_k_32x != k) { + int k_rem = k - tag_k_32x; + unsigned int tail_mask = (((unsigned int)0xffffffff) >> (32-k_rem)); + __m512i array512[16]; + + // Load and preprocess 1st 4 rows + array512[0] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x); + array512[1] = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x); + array512[2] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x); + array512[3] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x); + array512_0 = _mm512_unpacklo_epi32(array512[0], array512[1]); + array512_1 = _mm512_unpackhi_epi32(array512[0], array512[1]); + array512_2 = _mm512_unpacklo_epi32(array512[2], array512[3]); + array512_3 = _mm512_unpackhi_epi32(array512[2], array512[3]); + array512[0] = _mm512_unpacklo_epi64(array512_0, array512_2); + array512[1] = _mm512_unpackhi_epi64(array512_0, array512_2); + array512[2] = _mm512_unpacklo_epi64(array512_1, array512_3); + array512[3] = _mm512_unpackhi_epi64(array512_1, array512_3); + src_addr0 += LDA_4x; + src_addr1 += LDA_4x; + src_addr2 += LDA_4x; + src_addr3 += LDA_4x; + + // Load and preprocess 2nd 4 rows + array512[4] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x); + array512[5] = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x); + array512[6] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x); + array512[7] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x); + array512_0 = _mm512_unpacklo_epi32(array512[4], array512[5]); + array512_1 = _mm512_unpackhi_epi32(array512[4], array512[5]); + array512_2 = _mm512_unpacklo_epi32(array512[6], array512[7]); + array512_3 = _mm512_unpackhi_epi32(array512[6], array512[7]); + array512[4] = _mm512_unpacklo_epi64(array512_0, array512_2); + array512[5] = _mm512_unpackhi_epi64(array512_0, array512_2); + array512[6] = _mm512_unpacklo_epi64(array512_1, array512_3); + array512[7] = _mm512_unpackhi_epi64(array512_1, array512_3); + src_addr0 += LDA_4x; + src_addr1 += LDA_4x; + src_addr2 += LDA_4x; + src_addr3 += LDA_4x; + + // Load and preprocess 3rd 4 rows + array512[8] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x); + array512[9] = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x); + array512[10] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x); + array512[11] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x); + array512_0 = _mm512_unpacklo_epi32(array512[8], array512[9]); + array512_1 = _mm512_unpackhi_epi32(array512[8], array512[9]); + array512_2 = _mm512_unpacklo_epi32(array512[10], array512[11]); + array512_3 = _mm512_unpackhi_epi32(array512[10], 
array512[11]); + array512[8] = _mm512_unpacklo_epi64(array512_0, array512_2); + array512[9] = _mm512_unpackhi_epi64(array512_0, array512_2); + array512[10] = _mm512_unpacklo_epi64(array512_1, array512_3); + array512[11] = _mm512_unpackhi_epi64(array512_1, array512_3); + src_addr0 += LDA_4x; + src_addr1 += LDA_4x; + src_addr2 += LDA_4x; + src_addr3 += LDA_4x; + + // Load and preprocess 4th 4 rows + array512[12] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x); + array512[13] = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x); + array512[14] = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x); + array512[15] = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x); + array512_0 = _mm512_unpacklo_epi32(array512[12], array512[13]); + array512_1 = _mm512_unpackhi_epi32(array512[12], array512[13]); + array512_2 = _mm512_unpacklo_epi32(array512[14], array512[15]); + array512_3 = _mm512_unpackhi_epi32(array512[14], array512[15]); + array512[12] = _mm512_unpacklo_epi64(array512_0, array512_2); + array512[13] = _mm512_unpackhi_epi64(array512_0, array512_2); + array512[14] = _mm512_unpacklo_epi64(array512_1, array512_3); + array512[15] = _mm512_unpackhi_epi64(array512_1, array512_3); + + // array512_01_1617_0, array512_01_1617_1, array512_89_2425_0, array512_89_2425_1; + // Half-compose of 0/1, 16/17, 8/9, 24/25 cols + array512_0 = _mm512_permutex2var_epi64(array512[0], permute_lo_idx, array512[4]); + array512_1 = _mm512_permutex2var_epi64(array512[8], permute_lo_idx, array512[12]); + array512_2 = _mm512_permutex2var_epi64(array512[0], permute_hi_idx, array512[4]); + array512_3 = _mm512_permutex2var_epi64(array512[8], permute_hi_idx, array512[12]); + array512[0] = array512_0; // 1st 8 pairs of col 0/1, and 1st 8 pairs of col 16/17 + array512[4] = array512_1; // 2nd 8 pairs of col 0/1, and 2nd 8 pairs of col 16/17 + array512[8] = array512_2; // 1st 8 pairs of col 8/9, and 1st 8 pairs of col 24/25 + array512[12] = array512_3; // 2nd 8 pairs of col 8/9, and 2nd 8 pairs of col 24/25 + + // Half-compose of 2/3, 18/19, 10/11, 26/27 cols + array512_0 = _mm512_permutex2var_epi64(array512[1], permute_lo_idx, array512[5]); + array512_1 = _mm512_permutex2var_epi64(array512[9], permute_lo_idx, array512[13]); + array512_2 = _mm512_permutex2var_epi64(array512[1], permute_hi_idx, array512[5]); + array512_3 = _mm512_permutex2var_epi64(array512[9], permute_hi_idx, array512[13]); + array512[1] = array512_0; // 1st 8 pairs of col 2/3, and 1st 8 pairs of col 18/19 + array512[5] = array512_1; // 2nd 8 pairs of col 2/3, and 2nd 8 pairs of col 18/19 + array512[9] = array512_2; // 1st 8 pairs of col 10/11, and 1st 8 pairs of col 26/27 + array512[13] = array512_3; // 2nd 8 pairs of col 10/11, and 2nd 8 pairs of col 26/27 + + // Half-compose of 4/5, 20/21, 12/13, 28/29 cols + array512_0 = _mm512_permutex2var_epi64(array512[2], permute_lo_idx, array512[6]); + array512_1 = _mm512_permutex2var_epi64(array512[10], permute_lo_idx, array512[14]); + array512_2 = _mm512_permutex2var_epi64(array512[2], permute_hi_idx, array512[6]); + array512_3 = _mm512_permutex2var_epi64(array512[10], permute_hi_idx, array512[14]); + array512[2] = array512_0; // 1st 8 pairs of col 4/5, and 1st 8 pairs of col 20/21 + array512[6] = array512_1; // 2nd 8 pairs of col 4/5, and 2nd 8 pairs of col 20/21 + array512[10] = array512_2; // 1st 8 pairs of col 12/13, and 1st 8 pairs of col 28/29 + array512[14] = array512_3; // 2nd 8 pairs of col 12/13, and 2nd 8 pairs of col 28/29 + + // Half-compose of 6/7, 22/23, 14/15, 30/31 cols + 
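+            // (Each half-compose merges the 1st/2nd 8-row halves of one column pair taken
+            // from two 4-row groups; the inserti64x4 steps below then stitch two 256-bit
+            // halves into one full 16-row store.)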
array512_0 = _mm512_permutex2var_epi64(array512[3], permute_lo_idx, array512[7]); + array512_1 = _mm512_permutex2var_epi64(array512[11], permute_lo_idx, array512[15]); + array512_2 = _mm512_permutex2var_epi64(array512[3], permute_hi_idx, array512[7]); + array512_3 = _mm512_permutex2var_epi64(array512[11], permute_hi_idx, array512[15]); + array512[3] = array512_0; // 1st 8 pairs of col 6/7, and 1st 8 pairs of col 22/23 + array512[7] = array512_1; // 2nd 8 pairs of col 6/7, and 2nd 8 pairs of col 22/23 + array512[11] = array512_2; // 1st 8 pairs of col 14/15, and 1st 8 pairs of col 30/31 + array512[15] = array512_3; // 2nd 8 pairs of col 14/15, and 2nd 8 pairs of col 30/31 + + // Compose and store the 0/1 cols + array512_0 = _mm512_inserti64x4(array512[0], _mm512_castsi512_si256(array512[4]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 32; + + // Compose and store the 2/3 cols + array512_0 = _mm512_inserti64x4(array512[1], _mm512_castsi512_si256(array512[5]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 32; + + // Compose and store the 4/5 cols + array512_0 = _mm512_inserti64x4(array512[2], _mm512_castsi512_si256(array512[6]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 32; + + // Compose and store the 6/7 cols + array512_0 = _mm512_inserti64x4(array512[3], _mm512_castsi512_si256(array512[7]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 32; + + // Compose and store the 8/9 cols + array512_0 = _mm512_inserti64x4(array512[8], _mm512_castsi512_si256(array512[12]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 32; + + // Compose and store the 10/11 cols + array512_0 = _mm512_inserti64x4(array512[9], _mm512_castsi512_si256(array512[13]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 32; + + // Compose and store the 12/13 cols + array512_0 = _mm512_inserti64x4(array512[10], _mm512_castsi512_si256(array512[14]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 32; + + // Compose and store the 14/15 cols + array512_0 = _mm512_inserti64x4(array512[11], _mm512_castsi512_si256(array512[15]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 32; + + // Compose and store 16 ~ k_rem cols + int idx_length = (k_rem + 1 - 16) >> 1; + if (idx_length > 4) { + for (int idx_k = 0; idx_k < 4; idx_k++) { + array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 32; + } + + for (int idx_k = 4; idx_k < idx_length; idx_k++) { + array512_0 = _mm512_inserti64x4(array512[idx_k+8], _mm512_extracti64x4_epi64(array512[idx_k+4], 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 32; + } + } else { + for (int idx_k = 0; idx_k < idx_length; idx_k++) { + array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 32; + } + } + } +} + +// K=Any number but will be processed based on 32, M<=16 +void COL_MAJOR_ITCOPY_KERNEL_Kx16m(BLASLONG m, BLASLONG k, bfloat16 * A, BLASLONG lda, bfloat16 * block_A) +{ + bfloat16 * src_addr0; + bfloat16 * dst_addr0, * dst_addr1; + + BLASLONG tag_k_32x = k & (~31); + + src_addr0 = A; + dst_addr0 = block_A; + dst_addr1 = block_A + 32*8; + + __m512i array512_0, array512_1, array512_2, array512_3; + __m512i array512[16]; + + __m512i M512_EPI64_2 = _mm512_set1_epi64(2); + __m512i permute_lo_idx = 
_mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0); + __m512i permute_hi_idx = _mm512_add_epi64(permute_lo_idx, M512_EPI64_2); + + for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) { + for (int j = 0; j < m; j++) { + array512[j] = _mm512_loadu_si512(src_addr0+j*lda+idx_k); + } + for (int j = m; j < 16; j++) { + array512[j] = _mm512_setzero_si512(); + } + + for (int j = 0; j < 4; j++) { + int array_idx = j*4; + array512_0 = _mm512_unpacklo_epi32(array512[array_idx+0], array512[array_idx+1]); + array512_1 = _mm512_unpackhi_epi32(array512[array_idx+0], array512[array_idx+1]); + array512_2 = _mm512_unpacklo_epi32(array512[array_idx+2], array512[array_idx+3]); + array512_3 = _mm512_unpackhi_epi32(array512[array_idx+2], array512[array_idx+3]); + array512[array_idx+0] = _mm512_unpacklo_epi64(array512_0, array512_2); + array512[array_idx+1] = _mm512_unpackhi_epi64(array512_0, array512_2); + array512[array_idx+2] = _mm512_unpacklo_epi64(array512_1, array512_3); + array512[array_idx+3] = _mm512_unpackhi_epi64(array512_1, array512_3); + } + + // Compose and store the 0/1, 2/3, 4/5, 6/7 and 16/17, 18/19, 20/21, 22/23 cols + for (int j = 0; j < 4; j++) { + array512_0 = _mm512_permutex2var_epi64(array512[j+0], permute_lo_idx, array512[j+4]); + array512_1 = _mm512_permutex2var_epi64(array512[j+8], permute_lo_idx, array512[j+12]); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 32; + dst_addr1 += 32; + } + + // Compose and store the 8/9, 10/11, 12/13, 14/15 and 24/25, 26/27, 28/29, 30/31 cols + for (int j = 0; j < 4; j++) { + array512_0 = _mm512_permutex2var_epi64(array512[j+0], permute_hi_idx, array512[j+4]); + array512_1 = _mm512_permutex2var_epi64(array512[j+8], permute_hi_idx, array512[j+12]); + array512_2 = _mm512_inserti64x4(array512_0, _mm512_castsi512_si256(array512_1), 0x1); + array512_3 = _mm512_inserti64x4(array512_1, _mm512_extracti64x4_epi64(array512_0, 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_2); + _mm512_storeu_si512(dst_addr1, array512_3); + dst_addr0 += 32; + dst_addr1 += 32; + } + + dst_addr0 += 32*8; + dst_addr1 += 32*8; + } + + if (tag_k_32x != k) { + int k_rem = k - tag_k_32x; + unsigned int tail_mask = (((unsigned int)0xffffffff) >> (32-k_rem)); + + for (int j = 0; j < m; j++) { + array512[j] = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+j*lda+tag_k_32x); + } + for (int j = m; j < 16; j++) { + array512[j] = _mm512_setzero_si512(); + } + + for (int j = 0; j < 4; j++) { + int array_idx = j*4; + array512_0 = _mm512_unpacklo_epi32(array512[array_idx+0], array512[array_idx+1]); + array512_1 = _mm512_unpackhi_epi32(array512[array_idx+0], array512[array_idx+1]); + array512_2 = _mm512_unpacklo_epi32(array512[array_idx+2], array512[array_idx+3]); + array512_3 = _mm512_unpackhi_epi32(array512[array_idx+2], array512[array_idx+3]); + array512[array_idx+0] = _mm512_unpacklo_epi64(array512_0, array512_2); + array512[array_idx+1] = _mm512_unpackhi_epi64(array512_0, array512_2); + array512[array_idx+2] = _mm512_unpacklo_epi64(array512_1, array512_3); + array512[array_idx+3] = _mm512_unpackhi_epi64(array512_1, array512_3); + } + + for (int j = 0; j < 4; j++) { + array512_0 = _mm512_permutex2var_epi64(array512[j+0], permute_lo_idx, array512[j+4]); + array512_1 = _mm512_permutex2var_epi64(array512[j+8], permute_lo_idx, array512[j+12]); + array512_2 = 
_mm512_permutex2var_epi64(array512[j+0], permute_hi_idx, array512[j+4]); + array512_3 = _mm512_permutex2var_epi64(array512[j+8], permute_hi_idx, array512[j+12]); + array512[j+0] = array512_0; // 1st 8 pairs of col 0/1|2/3|4/5|6/7, and 1st 8 pairs of col 16/17|18/19|20/21|22/23 + array512[j+4] = array512_1; // 2nd 8 pairs of col 0/1|2/3|4/5|6/7, and 2nd 8 pairs of col 16/17|18/19|20/21|22/23 + array512[j+8] = array512_2; // 1st 8 pairs of col 8/9|10/11|12/13|14/15, and 1st 8 pairs of col 24/25|26/27|28/29|30/31 + array512[j+12] = array512_3; // 2nd 8 pairs of col 8/9|10/11|12/13|14/15, and 2nd 8 pairs of col 24/25|26/27|28/29|30/31 + } + + for (int j = 0; j < 4; j++) { + // Compose and store the 0/1 cols + array512_0 = _mm512_inserti64x4(array512[j], _mm512_castsi512_si256(array512[j+4]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 32; + } + + for (int j = 8; j < 12; j++) { + array512_0 = _mm512_inserti64x4(array512[j], _mm512_castsi512_si256(array512[j+4]), 0x1); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 32; + } + + // Compose and store 16 ~ k_rem cols + int idx_length = (k_rem + 1 - 16) >> 1; + if (idx_length > 4) { + for (int idx_k = 0; idx_k < 4; idx_k++) { + array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 32; + } + + for (int idx_k = 4; idx_k < idx_length; idx_k++) { + array512_0 = _mm512_inserti64x4(array512[idx_k+8], _mm512_extracti64x4_epi64(array512[idx_k+4], 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 32; + } + } else { + for (int idx_k = 0; idx_k < idx_length; idx_k++) { + array512_0 = _mm512_inserti64x4(array512[idx_k+4], _mm512_extracti64x4_epi64(array512[idx_k], 0x1), 0x0); + _mm512_storeu_si512(dst_addr0, array512_0); + dst_addr0 += 32; + } + } + } +} + +// COL_MAJOR_ONCOPY_KERNEL_16x32 behaves exactly the same as COL_MAJOR_ITCOPY_KERNEL_Kx16 +#define COL_MAJOR_ONCOPY_KERNEL_16x32 COL_MAJOR_ITCOPY_KERNEL_Kx16 + void COL_MAJOR_ONCOPY_KERNEL_8x32(BLASLONG k, bfloat16 * B, BLASLONG ldb, bfloat16 * block_B) { BLASLONG tag_k_32x = k & (~31); - BLASLONG idx_src_base0, idx_src_base1, idx_src_base2, idx_src_base3, idx_src_base4, idx_src_base5, idx_src_base6, idx_src_base7; - BLASLONG idx_target_base0; - idx_src_base0 = 0; - idx_src_base1 = 1*ldb; - idx_src_base2 = 2*ldb; - idx_src_base3 = 3*ldb; - idx_src_base4 = 4*ldb; - idx_src_base5 = 5*ldb; - idx_src_base6 = 6*ldb; - idx_src_base7 = 7*ldb; - idx_target_base0 = 0; + bfloat16 * src_addr0, * src_addr1, * src_addr2, * src_addr3, * src_addr4, * src_addr5, * src_addr6, * src_addr7; + bfloat16 * dst_addr0; + + unsigned char blend_mask = (((unsigned char)0xcc)); + __m512i permute_idx = _mm512_set_epi64(13, 12, 7, 6, 9, 8, 3, 2); + + src_addr0 = B; + src_addr1 = src_addr0 + 1*ldb; + src_addr2 = src_addr0 + 2*ldb; + src_addr3 = src_addr0 + 3*ldb; + src_addr4 = src_addr0 + 4*ldb; + src_addr5 = src_addr0 + 5*ldb; + src_addr6 = src_addr0 + 6*ldb; + src_addr7 = src_addr0 + 7*ldb; + dst_addr0 = block_B; + + __m512i array512_0, array512_1, array512_2, array512_3; + __m512i array512_way0_0, array512_way0_1, array512_way0_2, array512_way0_3; + __m512i array512_way1_0, array512_way1_1, array512_way1_2, array512_way1_3; for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) { - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*0], _mm512_loadu_si512(&B[idx_src_base0+idx_k])); - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*1], 
_mm512_loadu_si512(&B[idx_src_base1+idx_k])); - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*2], _mm512_loadu_si512(&B[idx_src_base2+idx_k])); - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*3], _mm512_loadu_si512(&B[idx_src_base3+idx_k])); - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*4], _mm512_loadu_si512(&B[idx_src_base4+idx_k])); - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*5], _mm512_loadu_si512(&B[idx_src_base5+idx_k])); - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*6], _mm512_loadu_si512(&B[idx_src_base6+idx_k])); - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*7], _mm512_loadu_si512(&B[idx_src_base7+idx_k])); - idx_target_base0 += 32*8; + array512_0 = _mm512_loadu_si512(src_addr0+idx_k); + array512_1 = _mm512_loadu_si512(src_addr1+idx_k); + array512_2 = _mm512_loadu_si512(src_addr2+idx_k); + array512_3 = _mm512_loadu_si512(src_addr3+idx_k); + + array512_way0_0 = _mm512_unpacklo_epi32(array512_0, array512_1); + array512_way0_1 = _mm512_unpackhi_epi32(array512_0, array512_1); + array512_way0_2 = _mm512_unpacklo_epi32(array512_2, array512_3); + array512_way0_3 = _mm512_unpackhi_epi32(array512_2, array512_3); + + array512_0 = _mm512_unpacklo_epi64(array512_way0_0, array512_way0_2); + array512_1 = _mm512_unpackhi_epi64(array512_way0_0, array512_way0_2); + array512_2 = _mm512_unpacklo_epi64(array512_way0_1, array512_way0_3); + array512_3 = _mm512_unpackhi_epi64(array512_way0_1, array512_way0_3); + + array512_way0_0 = _mm512_shuffle_i32x4(array512_0, array512_1, 0x88); + array512_way0_2 = _mm512_shuffle_i32x4(array512_0, array512_1, 0xdd); + array512_way0_1 = _mm512_shuffle_i32x4(array512_2, array512_3, 0x88); + array512_way0_3 = _mm512_shuffle_i32x4(array512_2, array512_3, 0xdd); + + array512_0 = _mm512_loadu_si512(src_addr4+idx_k); + array512_1 = _mm512_loadu_si512(src_addr5+idx_k); + array512_2 = _mm512_loadu_si512(src_addr6+idx_k); + array512_3 = _mm512_loadu_si512(src_addr7+idx_k); + + array512_way1_0 = _mm512_unpacklo_epi32(array512_0, array512_1); + array512_way1_1 = _mm512_unpackhi_epi32(array512_0, array512_1); + array512_way1_2 = _mm512_unpacklo_epi32(array512_2, array512_3); + array512_way1_3 = _mm512_unpackhi_epi32(array512_2, array512_3); + + array512_0 = _mm512_unpacklo_epi64(array512_way1_0, array512_way1_2); + array512_1 = _mm512_unpackhi_epi64(array512_way1_0, array512_way1_2); + array512_2 = _mm512_unpacklo_epi64(array512_way1_1, array512_way1_3); + array512_3 = _mm512_unpackhi_epi64(array512_way1_1, array512_way1_3); + + array512_way1_0 = _mm512_shuffle_i32x4(array512_0, array512_1, 0x22); + array512_way1_2 = _mm512_shuffle_i32x4(array512_0, array512_1, 0x77); + array512_way1_1 = _mm512_shuffle_i32x4(array512_2, array512_3, 0x22); + array512_way1_3 = _mm512_shuffle_i32x4(array512_2, array512_3, 0x77); + + array512_0 = _mm512_mask_blend_epi64(blend_mask, array512_way0_0, array512_way1_0); + array512_1 = _mm512_mask_blend_epi64(blend_mask, array512_way0_1, array512_way1_1); + array512_2 = _mm512_mask_blend_epi64(blend_mask, array512_way0_2, array512_way1_2); + array512_3 = _mm512_mask_blend_epi64(blend_mask, array512_way0_3, array512_way1_3); + _mm512_storeu_si512(dst_addr0, array512_0); + _mm512_storeu_si512(dst_addr0+32, array512_1); + _mm512_storeu_si512(dst_addr0+64, array512_2); + _mm512_storeu_si512(dst_addr0+96, array512_3); + + array512_0 = _mm512_permutex2var_epi64(array512_way0_0, permute_idx, array512_way1_0); + array512_1 = _mm512_permutex2var_epi64(array512_way0_1, permute_idx, array512_way1_1); + array512_2 = 
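+        // blend_mask 0xcc and permute_idx select complementary 128-bit quadrants of the
+        // two 4-row groups: the blends above assemble the stores for source columns 0..15,
+        // while these permutes assemble the stores for columns 16..31.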
_mm512_permutex2var_epi64(array512_way0_2, permute_idx, array512_way1_2); + array512_3 = _mm512_permutex2var_epi64(array512_way0_3, permute_idx, array512_way1_3); + _mm512_storeu_si512(dst_addr0+128, array512_0); + _mm512_storeu_si512(dst_addr0+160, array512_1); + _mm512_storeu_si512(dst_addr0+192, array512_2); + _mm512_storeu_si512(dst_addr0+224, array512_3); + + dst_addr0 += 256; } if (tag_k_32x != k) { unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-(k-tag_k_32x))); __mmask32 tail_mask = *((__mmask32*) &tail_mask_value); - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*0], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base0+tag_k_32x])); - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*1], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base1+tag_k_32x])); - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*2], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base2+tag_k_32x])); - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*3], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base3+tag_k_32x])); - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*4], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base4+tag_k_32x])); - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*5], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base5+tag_k_32x])); - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*6], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base6+tag_k_32x])); - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*7], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base7+tag_k_32x])); + array512_0 = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x); + array512_1 = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x); + array512_2 = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x); + array512_3 = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x); + + array512_way0_0 = _mm512_unpacklo_epi32(array512_0, array512_1); + array512_way0_1 = _mm512_unpackhi_epi32(array512_0, array512_1); + array512_way0_2 = _mm512_unpacklo_epi32(array512_2, array512_3); + array512_way0_3 = _mm512_unpackhi_epi32(array512_2, array512_3); + + array512_0 = _mm512_unpacklo_epi64(array512_way0_0, array512_way0_2); + array512_1 = _mm512_unpackhi_epi64(array512_way0_0, array512_way0_2); + array512_2 = _mm512_unpacklo_epi64(array512_way0_1, array512_way0_3); + array512_3 = _mm512_unpackhi_epi64(array512_way0_1, array512_way0_3); + + array512_way0_0 = _mm512_shuffle_i32x4(array512_0, array512_1, 0x88); + array512_way0_2 = _mm512_shuffle_i32x4(array512_0, array512_1, 0xdd); + array512_way0_1 = _mm512_shuffle_i32x4(array512_2, array512_3, 0x88); + array512_way0_3 = _mm512_shuffle_i32x4(array512_2, array512_3, 0xdd); + + array512_0 = _mm512_maskz_loadu_epi16(tail_mask, src_addr4+tag_k_32x); + array512_1 = _mm512_maskz_loadu_epi16(tail_mask, src_addr5+tag_k_32x); + array512_2 = _mm512_maskz_loadu_epi16(tail_mask, src_addr6+tag_k_32x); + array512_3 = _mm512_maskz_loadu_epi16(tail_mask, src_addr7+tag_k_32x); + + array512_way1_0 = _mm512_unpacklo_epi32(array512_0, array512_1); + array512_way1_1 = _mm512_unpackhi_epi32(array512_0, array512_1); + array512_way1_2 = _mm512_unpacklo_epi32(array512_2, array512_3); + array512_way1_3 = _mm512_unpackhi_epi32(array512_2, array512_3); + + array512_0 = _mm512_unpacklo_epi64(array512_way1_0, array512_way1_2); + array512_1 = _mm512_unpackhi_epi64(array512_way1_0, array512_way1_2); + array512_2 = _mm512_unpacklo_epi64(array512_way1_1, array512_way1_3); + array512_3 = _mm512_unpackhi_epi64(array512_way1_1, array512_way1_3); + + array512_way1_0 = 
_mm512_shuffle_i32x4(array512_0, array512_1, 0x22); + array512_way1_2 = _mm512_shuffle_i32x4(array512_0, array512_1, 0x77); + array512_way1_1 = _mm512_shuffle_i32x4(array512_2, array512_3, 0x22); + array512_way1_3 = _mm512_shuffle_i32x4(array512_2, array512_3, 0x77); + + + array512_0 = _mm512_mask_blend_epi64(blend_mask, array512_way0_0, array512_way1_0); + array512_1 = _mm512_mask_blend_epi64(blend_mask, array512_way0_1, array512_way1_1); + array512_2 = _mm512_mask_blend_epi64(blend_mask, array512_way0_2, array512_way1_2); + array512_3 = _mm512_mask_blend_epi64(blend_mask, array512_way0_3, array512_way1_3); + _mm512_storeu_si512(dst_addr0, array512_0); + _mm512_storeu_si512(dst_addr0+32, array512_1); + _mm512_storeu_si512(dst_addr0+64, array512_2); + _mm512_storeu_si512(dst_addr0+96, array512_3); + + array512_0 = _mm512_permutex2var_epi64(array512_way0_0, permute_idx, array512_way1_0); + array512_1 = _mm512_permutex2var_epi64(array512_way0_1, permute_idx, array512_way1_1); + array512_2 = _mm512_permutex2var_epi64(array512_way0_2, permute_idx, array512_way1_2); + array512_3 = _mm512_permutex2var_epi64(array512_way0_3, permute_idx, array512_way1_3); + _mm512_storeu_si512(dst_addr0+128, array512_0); + _mm512_storeu_si512(dst_addr0+160, array512_1); + _mm512_storeu_si512(dst_addr0+192, array512_2); + _mm512_storeu_si512(dst_addr0+224, array512_3); + } +} + +void COL_MAJOR_ONCOPY_KERNEL_4x32(BLASLONG k, bfloat16 * B, BLASLONG ldb, bfloat16 * block_B) +{ + BLASLONG tag_k_32x = k & (~31); + + bfloat16 * src_addr0, * src_addr1, * src_addr2, * src_addr3; + bfloat16 * dst_addr0; + + src_addr0 = B; + src_addr1 = src_addr0 + 1*ldb; + src_addr2 = src_addr0 + 2*ldb; + src_addr3 = src_addr0 + 3*ldb; + dst_addr0 = block_B; + + __m512i array512_0, array512_1, array512_2, array512_3; + __m512i array512_way0_0, array512_way0_1, array512_way0_2, array512_way0_3; + + for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) { + array512_0 = _mm512_loadu_si512(src_addr0+idx_k); + array512_1 = _mm512_loadu_si512(src_addr1+idx_k); + array512_2 = _mm512_loadu_si512(src_addr2+idx_k); + array512_3 = _mm512_loadu_si512(src_addr3+idx_k); + + array512_way0_0 = _mm512_unpacklo_epi32(array512_0, array512_1); + array512_way0_1 = _mm512_unpackhi_epi32(array512_0, array512_1); + array512_way0_2 = _mm512_unpacklo_epi32(array512_2, array512_3); + array512_way0_3 = _mm512_unpackhi_epi32(array512_2, array512_3); + + array512_0 = _mm512_unpacklo_epi64(array512_way0_0, array512_way0_2); + array512_1 = _mm512_unpackhi_epi64(array512_way0_0, array512_way0_2); + array512_2 = _mm512_unpacklo_epi64(array512_way0_1, array512_way0_3); + array512_3 = _mm512_unpackhi_epi64(array512_way0_1, array512_way0_3); + + array512_way0_0 = _mm512_shuffle_i32x4(array512_0, array512_1, 0x88); + array512_way0_2 = _mm512_shuffle_i32x4(array512_0, array512_1, 0xdd); + array512_way0_1 = _mm512_shuffle_i32x4(array512_2, array512_3, 0x88); + array512_way0_3 = _mm512_shuffle_i32x4(array512_2, array512_3, 0xdd); + + array512_0 = _mm512_shuffle_i32x4(array512_way0_0, array512_way0_1, 0x88); + array512_1 = _mm512_shuffle_i32x4(array512_way0_2, array512_way0_3, 0x88); + array512_2 = _mm512_shuffle_i32x4(array512_way0_0, array512_way0_1, 0xdd); + array512_3 = _mm512_shuffle_i32x4(array512_way0_2, array512_way0_3, 0xdd); + + _mm512_storeu_si512(dst_addr0, array512_0); + _mm512_storeu_si512(dst_addr0+32, array512_1); + _mm512_storeu_si512(dst_addr0+64, array512_2); + _mm512_storeu_si512(dst_addr0+96, array512_3); + + dst_addr0 += 128; } -#ifdef DEBUG_PROFILE - 
print_block(BF16_BLOCK_THRES_N, BF16_BLOCK_THRES_K, block_B); -#endif + if (tag_k_32x != k) { + unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-(k-tag_k_32x))); + __mmask32 tail_mask = *((__mmask32*) &tail_mask_value); + array512_0 = _mm512_maskz_loadu_epi16(tail_mask, src_addr0+tag_k_32x); + array512_1 = _mm512_maskz_loadu_epi16(tail_mask, src_addr1+tag_k_32x); + array512_2 = _mm512_maskz_loadu_epi16(tail_mask, src_addr2+tag_k_32x); + array512_3 = _mm512_maskz_loadu_epi16(tail_mask, src_addr3+tag_k_32x); + + array512_way0_0 = _mm512_unpacklo_epi32(array512_0, array512_1); + array512_way0_1 = _mm512_unpackhi_epi32(array512_0, array512_1); + array512_way0_2 = _mm512_unpacklo_epi32(array512_2, array512_3); + array512_way0_3 = _mm512_unpackhi_epi32(array512_2, array512_3); + + array512_0 = _mm512_unpacklo_epi64(array512_way0_0, array512_way0_2); + array512_1 = _mm512_unpackhi_epi64(array512_way0_0, array512_way0_2); + array512_2 = _mm512_unpacklo_epi64(array512_way0_1, array512_way0_3); + array512_3 = _mm512_unpackhi_epi64(array512_way0_1, array512_way0_3); + + array512_way0_0 = _mm512_shuffle_i32x4(array512_0, array512_1, 0x88); + array512_way0_2 = _mm512_shuffle_i32x4(array512_0, array512_1, 0xdd); + array512_way0_1 = _mm512_shuffle_i32x4(array512_2, array512_3, 0x88); + array512_way0_3 = _mm512_shuffle_i32x4(array512_2, array512_3, 0xdd); + + array512_0 = _mm512_shuffle_i32x4(array512_way0_0, array512_way0_1, 0x88); + array512_1 = _mm512_shuffle_i32x4(array512_way0_2, array512_way0_3, 0x88); + array512_2 = _mm512_shuffle_i32x4(array512_way0_0, array512_way0_1, 0xdd); + array512_3 = _mm512_shuffle_i32x4(array512_way0_2, array512_way0_3, 0xdd); + + _mm512_storeu_si512(dst_addr0, array512_0); + _mm512_storeu_si512(dst_addr0+32, array512_1); + _mm512_storeu_si512(dst_addr0+64, array512_2); + _mm512_storeu_si512(dst_addr0+96, array512_3); + } } void COL_MAJOR_ONCOPY_KERNEL_Nx32(BLASLONG n, BLASLONG k, bfloat16 * B, BLASLONG ldb, bfloat16 * block_B) { BLASLONG tag_k_32x = k & (~31); BLASLONG tag_n_2x = n & (~1); - BLASLONG idx_src_base0; - BLASLONG idx_target_base0; + + bfloat16 * src_addr0; + bfloat16 * dst_addr0; BLASLONG LDB_2x = 2*ldb; - idx_target_base0 = 0; + src_addr0 = B; + dst_addr0 = block_B; for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) { - idx_src_base0 = 0; + src_addr0 = B; for (BLASLONG idx_n = 0; idx_n < tag_n_2x; idx_n += 2) { - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*0], _mm512_loadu_si512(&B[idx_src_base0 + idx_k])); - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*1], _mm512_loadu_si512(&B[idx_src_base0 + ldb + idx_k])); - idx_src_base0 += LDB_2x; - idx_target_base0 += 64; + _mm512_storeu_si512(dst_addr0, _mm512_loadu_si512(src_addr0 + idx_k)); + _mm512_storeu_si512(dst_addr0 + 32, _mm512_loadu_si512(src_addr0 + ldb + idx_k)); + src_addr0 += LDB_2x; + dst_addr0 += 64; } if (tag_n_2x != n) { - _mm512_storeu_si512(&block_B[idx_target_base0], _mm512_loadu_si512(&B[idx_src_base0 + idx_k])); - idx_target_base0 += 32; + _mm512_storeu_si512(dst_addr0, _mm512_loadu_si512(src_addr0 + idx_k)); + dst_addr0 += 32; } } if (tag_k_32x != k) { unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-(k-tag_k_32x))); __mmask32 tail_mask = *((__mmask32*) &tail_mask_value); - idx_src_base0 = 0; + src_addr0 = B; for (BLASLONG idx_n = 0; idx_n < tag_n_2x; idx_n += 2) { - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*0], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base0 + tag_k_32x])); - _mm512_storeu_si512(&block_B[idx_target_base0+ 32*1], 
_mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base0 + ldb + tag_k_32x])); - idx_src_base0 += LDB_2x; - idx_target_base0 += 64; + _mm512_storeu_si512(dst_addr0, _mm512_maskz_loadu_epi16(tail_mask, src_addr0 + tag_k_32x)); + _mm512_storeu_si512(dst_addr0 + 32, _mm512_maskz_loadu_epi16(tail_mask, src_addr0 + ldb + tag_k_32x)); + src_addr0 += LDB_2x; + dst_addr0 += 64; } if (tag_n_2x != n) { - _mm512_storeu_si512(&block_B[idx_target_base0], _mm512_maskz_loadu_epi16(tail_mask, &B[idx_src_base0 + tag_k_32x])); + _mm512_storeu_si512(dst_addr0, _mm512_maskz_loadu_epi16(tail_mask, src_addr0 + tag_k_32x)); } } - -#ifdef DEBUG_PROFILE - print_block(BF16_BLOCK_THRES_N, BF16_BLOCK_THRES_K, block_B); -#endif } -// Scale matrix C while beta is not ZERO or ONE -void sbgemm_scal_operation(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST float beta, float *C, OPENBLAS_CONST blasint ldc) +void COL_MAJOR_OTCOPY_KERNEL_Kx8(BLASLONG k, bfloat16 * B, BLASLONG ldb, bfloat16 * block_B) { - BLASLONG tag_n_Nx = N & (~3); - BLASLONG tag_n_Mx = M & (~15); + BLASLONG tag_k_2x = k & (~1); + unsigned char tail_mask_value = (unsigned char) 0xff; + __mmask8 tail_mask = *((__mmask8*) &tail_mask_value); + + __m128i array128_0, array128_1, array128_2, array128_3; + + BLASLONG idx_src_base0, idx_src_base1; + BLASLONG idx_target_base0, idx_target_base1; + + BLASLONG LDA_2x = 2*ldb; + BLASLONG BF16_BLOCK_T_M_2x = 2*8; + idx_src_base0 = 0; + idx_src_base1 = ldb; + idx_target_base0 = 0; + idx_target_base1 = 8; + for (BLASLONG idx_k = 0; idx_k < tag_k_2x; idx_k += 2) { + array128_0 = _mm_maskz_loadu_epi16(tail_mask, &B[idx_src_base0]); + array128_1 = _mm_maskz_loadu_epi16(tail_mask, &B[idx_src_base1]); + array128_2 = _mm_unpacklo_epi16(array128_0, array128_1); + array128_3 = _mm_unpackhi_epi16(array128_0, array128_1); + _mm_storeu_epi32(&block_B[idx_target_base0], array128_2); + _mm_storeu_epi32(&block_B[idx_target_base1], array128_3); + + idx_src_base0 += LDA_2x; + idx_src_base1 += LDA_2x; + idx_target_base0 += BF16_BLOCK_T_M_2x; + idx_target_base1 += BF16_BLOCK_T_M_2x; + } + + if (tag_k_2x != k) { + __m128i ZERO128 = _mm_setzero_si128(); + array128_0 = _mm_maskz_loadu_epi16(tail_mask, &B[idx_src_base0]); + array128_2 = _mm_unpacklo_epi16(array128_0, ZERO128); + array128_3 = _mm_unpackhi_epi16(array128_0, ZERO128); + _mm_storeu_epi32(&block_B[idx_target_base0], array128_2); + _mm_storeu_epi32(&block_B[idx_target_base1], array128_3); + } +} + +void COL_MAJOR_OTCOPY_KERNEL_Kx8m(BLASLONG k, BLASLONG n, bfloat16 * B, BLASLONG ldb, bfloat16 * block_B) +{ + BLASLONG tag_k_2x = k & (~1); + unsigned char tail_mask = (((unsigned char)0xff) >> (8-n)); + + __m128i array128_0, array128_1, array128_2, array128_3; + + BLASLONG idx_src_base0, idx_src_base1; + BLASLONG idx_target_base0, idx_target_base1; + + BLASLONG LDA_2x = 2*ldb; + BLASLONG BF16_BLOCK_T_M_2x = 2*8; + idx_src_base0 = 0; + idx_src_base1 = ldb; + idx_target_base0 = 0; + idx_target_base1 = 8; + for (BLASLONG idx_k = 0; idx_k < tag_k_2x; idx_k += 2) { + array128_0 = _mm_maskz_loadu_epi16(tail_mask, &B[idx_src_base0]); + array128_1 = _mm_maskz_loadu_epi16(tail_mask, &B[idx_src_base1]); + array128_2 = _mm_unpacklo_epi16(array128_0, array128_1); + array128_3 = _mm_unpackhi_epi16(array128_0, array128_1); + _mm_storeu_epi32(&block_B[idx_target_base0], array128_2); + _mm_storeu_epi32(&block_B[idx_target_base1], array128_3); + + idx_src_base0 += LDA_2x; + idx_src_base1 += LDA_2x; + idx_target_base0 += BF16_BLOCK_T_M_2x; + 
idx_target_base1 += BF16_BLOCK_T_M_2x; + } + + if (tag_k_2x != k) { + __m128i ZERO128 = _mm_setzero_si128(); + array128_0 = _mm_maskz_loadu_epi16(tail_mask, &B[idx_src_base0]); + array128_2 = _mm_unpacklo_epi16(array128_0, ZERO128); + array128_3 = _mm_unpackhi_epi16(array128_0, ZERO128); + _mm_storeu_epi32(&block_B[idx_target_base0], array128_2); + _mm_storeu_epi32(&block_B[idx_target_base1], array128_3); + } +} + +// Scale matrix C when beta is not ZERO or ONE +void sbgemm_scal_operation(BLASLONG M, BLASLONG N, float beta, float *C, BLASLONG ldc) +{ + float * C_addr0 = C; + float * C_addr1 = C + ldc; + float * C_addr2 = C + ldc*2; + float * C_addr3 = C + ldc*3; BLASLONG LDC4x = ldc*4; - BLASLONG idx_base_0 = 0; - BLASLONG idx_base_1 = ldc; - BLASLONG idx_base_2 = ldc*2; - BLASLONG idx_base_3 = ldc*3; - - unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-M+tag_n_Mx)); - __mmask16 tail_mask = *((__mmask16*) &tail_mask_value); __m512 array_512_0, array_512_1, array_512_2, array_512_3; + __m512 BETAVECTOR = _mm512_set1_ps(beta); - __m512 BETAVECTOR = _mm512_set1_ps(beta); + BLASLONG tag_n_Nx = N & (~3); + BLASLONG tag_n_Mx = M & (~15); + unsigned short tail_mask = (((unsigned short)0xffff) >> (16-M+tag_n_Mx)); + for (BLASLONG idx_n = 0; idx_n < tag_n_Nx; idx_n += 4) { + for (BLASLONG idx_m = 0; idx_m < tag_n_Mx; idx_m += 16) { + array_512_0 = _mm512_loadu_ps(C_addr0 + idx_m); + array_512_1 = _mm512_loadu_ps(C_addr1 + idx_m); + array_512_2 = _mm512_loadu_ps(C_addr2 + idx_m); + array_512_3 = _mm512_loadu_ps(C_addr3 + idx_m); - if (Order == CblasColMajor) { - for (BLASLONG idx_n = 0; idx_n < tag_n_Nx; idx_n += 4) { + array_512_0 = _mm512_mul_ps(BETAVECTOR, array_512_0); + array_512_1 = _mm512_mul_ps(BETAVECTOR, array_512_1); + array_512_2 = _mm512_mul_ps(BETAVECTOR, array_512_2); + array_512_3 = _mm512_mul_ps(BETAVECTOR, array_512_3); + + _mm512_storeu_ps(C_addr0 + idx_m, array_512_0); + _mm512_storeu_ps(C_addr1 + idx_m, array_512_1); + _mm512_storeu_ps(C_addr2 + idx_m, array_512_2); + _mm512_storeu_ps(C_addr3 + idx_m, array_512_3); + } + + if (tag_n_Mx != M) { + array_512_0 = _mm512_maskz_loadu_ps(tail_mask, C_addr0 + tag_n_Mx); + array_512_1 = _mm512_maskz_loadu_ps(tail_mask, C_addr1 + tag_n_Mx); + array_512_2 = _mm512_maskz_loadu_ps(tail_mask, C_addr2 + tag_n_Mx); + array_512_3 = _mm512_maskz_loadu_ps(tail_mask, C_addr3 + tag_n_Mx); + + array_512_0 = _mm512_mul_ps(BETAVECTOR, array_512_0); + array_512_1 = _mm512_mul_ps(BETAVECTOR, array_512_1); + array_512_2 = _mm512_mul_ps(BETAVECTOR, array_512_2); + array_512_3 = _mm512_mul_ps(BETAVECTOR, array_512_3); + + _mm512_mask_storeu_ps(C_addr0 + tag_n_Mx, tail_mask, array_512_0); + _mm512_mask_storeu_ps(C_addr1 + tag_n_Mx, tail_mask, array_512_1); + _mm512_mask_storeu_ps(C_addr2 + tag_n_Mx, tail_mask, array_512_2); + _mm512_mask_storeu_ps(C_addr3 + tag_n_Mx, tail_mask, array_512_3); + } + + C_addr0 += LDC4x; + C_addr1 += LDC4x; + C_addr2 += LDC4x; + C_addr3 += LDC4x; + } + + if (tag_n_Nx != N) { + for (BLASLONG idx_n = tag_n_Nx; idx_n < N; idx_n++) { for (BLASLONG idx_m = 0; idx_m < tag_n_Mx; idx_m += 16) { - array_512_0 = _mm512_loadu_ps(&C[idx_base_0+idx_m]); - array_512_1 = _mm512_loadu_ps(&C[idx_base_1+idx_m]); - array_512_2 = _mm512_loadu_ps(&C[idx_base_2+idx_m]); - array_512_3 = _mm512_loadu_ps(&C[idx_base_3+idx_m]); - + array_512_0 = _mm512_loadu_ps(C_addr0 + idx_m); array_512_0 = _mm512_mul_ps(BETAVECTOR, array_512_0); - array_512_1 = _mm512_mul_ps(BETAVECTOR, array_512_1); - array_512_2 = _mm512_mul_ps(BETAVECTOR, 
array_512_2); - array_512_3 = _mm512_mul_ps(BETAVECTOR, array_512_3); - - _mm512_storeu_ps(&C[idx_base_0+idx_m], array_512_0); - _mm512_storeu_ps(&C[idx_base_1+idx_m], array_512_1); - _mm512_storeu_ps(&C[idx_base_2+idx_m], array_512_2); - _mm512_storeu_ps(&C[idx_base_3+idx_m], array_512_3); + _mm512_storeu_ps(C_addr0 + idx_m, array_512_0); } if (tag_n_Mx != M) { - array_512_0 = _mm512_maskz_loadu_ps(tail_mask, &C[idx_base_0+tag_n_Mx]); - array_512_1 = _mm512_maskz_loadu_ps(tail_mask, &C[idx_base_1+tag_n_Mx]); - array_512_2 = _mm512_maskz_loadu_ps(tail_mask, &C[idx_base_2+tag_n_Mx]); - array_512_3 = _mm512_maskz_loadu_ps(tail_mask, &C[idx_base_3+tag_n_Mx]); - + array_512_0 = _mm512_maskz_loadu_ps(tail_mask, C_addr0 + tag_n_Mx); array_512_0 = _mm512_mul_ps(BETAVECTOR, array_512_0); - array_512_1 = _mm512_mul_ps(BETAVECTOR, array_512_1); - array_512_2 = _mm512_mul_ps(BETAVECTOR, array_512_2); - array_512_3 = _mm512_mul_ps(BETAVECTOR, array_512_3); - - _mm512_mask_storeu_ps(&C[idx_base_0+tag_n_Mx], tail_mask, array_512_0); - _mm512_mask_storeu_ps(&C[idx_base_1+tag_n_Mx], tail_mask, array_512_1); - _mm512_mask_storeu_ps(&C[idx_base_2+tag_n_Mx], tail_mask, array_512_2); - _mm512_mask_storeu_ps(&C[idx_base_3+tag_n_Mx], tail_mask, array_512_3); + _mm512_mask_storeu_ps(C_addr0 + tag_n_Mx, tail_mask, array_512_0); } - - idx_base_0 += LDC4x; - idx_base_1 += LDC4x; - idx_base_2 += LDC4x; - idx_base_3 += LDC4x; + C_addr0 += ldc; } - - if (tag_n_Nx != N) { - for (BLASLONG idx_n = tag_n_Nx; idx_n < N; idx_n++) { - for (BLASLONG idx_m = 0; idx_m < tag_n_Mx; idx_m += 16) { - array_512_0 = _mm512_loadu_ps(&C[idx_base_0+idx_m]); - array_512_0 = _mm512_mul_ps(BETAVECTOR, array_512_0); - _mm512_storeu_ps(&C[idx_base_0+idx_m], array_512_0); - } - - if (tag_n_Mx != M) { - array_512_0 = _mm512_maskz_loadu_ps(tail_mask, &C[idx_base_0+tag_n_Mx]); - array_512_0 = _mm512_mul_ps(BETAVECTOR, array_512_0); - _mm512_mask_storeu_ps(&C[idx_base_0+tag_n_Mx], tail_mask, array_512_0); - } - idx_base_0 += ldc; - } - } - } else { - } } -// Scale matrix C while beta is not ZERO or ONE -void sbgemm_zero_operation(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, float *C, OPENBLAS_CONST blasint ldc) +// Zero C matrix when Beta is 0 +void sbgemm_zero_operation(BLASLONG M, BLASLONG N, float *C, BLASLONG ldc) { - BLASLONG tag_n_Nx = N & (~3); - BLASLONG tag_n_Mx = M & (~15); + float * C_addr0 = C; + float * C_addr1 = C + ldc; + float * C_addr2 = C + ldc*2; + float * C_addr3 = C + ldc*3; BLASLONG LDC4x = ldc*4; - BLASLONG idx_base_0 = 0; - BLASLONG idx_base_1 = ldc; - BLASLONG idx_base_2 = ldc*2; - BLASLONG idx_base_3 = ldc*3; - - unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-M+tag_n_Mx)); - __mmask16 tail_mask = *((__mmask16*) &tail_mask_value); __m512 ZEROVECTOR = _mm512_setzero_ps(); - if (Order == CblasColMajor) { - for (BLASLONG idx_n = 0; idx_n < tag_n_Nx; idx_n += 4) { + BLASLONG tag_n_Nx = N & (~3); + BLASLONG tag_n_Mx = M & (~15); + unsigned short tail_mask = (((unsigned short)0xffff) >> (16-M+tag_n_Mx)); + for (BLASLONG idx_n = 0; idx_n < tag_n_Nx; idx_n += 4) { + for (BLASLONG idx_m = 0; idx_m < tag_n_Mx; idx_m += 16) { + _mm512_storeu_ps(C_addr0 + idx_m, ZEROVECTOR); + _mm512_storeu_ps(C_addr1 + idx_m, ZEROVECTOR); + _mm512_storeu_ps(C_addr2 + idx_m, ZEROVECTOR); + _mm512_storeu_ps(C_addr3 + idx_m, ZEROVECTOR); + } + + if (tag_n_Mx != M) { + _mm512_mask_storeu_ps(C_addr0 + tag_n_Mx, tail_mask, ZEROVECTOR); + _mm512_mask_storeu_ps(C_addr1 + tag_n_Mx, tail_mask, 
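+ // tail_mask keeps the low (M - tag_n_Mx) bits, selecting exactly the
+ // rows left over after the 16-wide main loop; the idiom in isolation
+ // (function name illustrative only):
+#if 0 /* reference sketch, excluded from the build */
+static unsigned short row_tail_mask(BLASLONG M, BLASLONG tag_n_Mx)
+{
+    /* e.g. M = 21, tag_n_Mx = 16: 5 rows remain -> mask = 0x001f */
+    return (unsigned short)(0xffffu >> (16 - (M - tag_n_Mx)));
+}
+#endif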
ZEROVECTOR); + _mm512_mask_storeu_ps(C_addr2 + tag_n_Mx, tail_mask, ZEROVECTOR); + _mm512_mask_storeu_ps(C_addr3 + tag_n_Mx, tail_mask, ZEROVECTOR); + } + + C_addr0 += LDC4x; + C_addr1 += LDC4x; + C_addr2 += LDC4x; + C_addr3 += LDC4x; + } + + if (tag_n_Nx != N) { + for (BLASLONG idx_n = tag_n_Nx; idx_n < N; idx_n++) { for (BLASLONG idx_m = 0; idx_m < tag_n_Mx; idx_m += 16) { - _mm512_storeu_ps(&C[idx_base_0+idx_m], ZEROVECTOR); - _mm512_storeu_ps(&C[idx_base_1+idx_m], ZEROVECTOR); - _mm512_storeu_ps(&C[idx_base_2+idx_m], ZEROVECTOR); - _mm512_storeu_ps(&C[idx_base_3+idx_m], ZEROVECTOR); + _mm512_storeu_ps(C_addr0 + idx_m, ZEROVECTOR); } if (tag_n_Mx != M) { - _mm512_mask_storeu_ps(&C[idx_base_0+tag_n_Mx], tail_mask, ZEROVECTOR); - _mm512_mask_storeu_ps(&C[idx_base_1+tag_n_Mx], tail_mask, ZEROVECTOR); - _mm512_mask_storeu_ps(&C[idx_base_2+tag_n_Mx], tail_mask, ZEROVECTOR); - _mm512_mask_storeu_ps(&C[idx_base_3+tag_n_Mx], tail_mask, ZEROVECTOR); + _mm512_mask_storeu_ps(C_addr0 + tag_n_Mx, tail_mask, ZEROVECTOR); } - - idx_base_0 += LDC4x; - idx_base_1 += LDC4x; - idx_base_2 += LDC4x; - idx_base_3 += LDC4x; + C_addr0 += ldc; } - - if (tag_n_Nx != N) { - for (BLASLONG idx_n = tag_n_Nx; idx_n < N; idx_n++) { - for (BLASLONG idx_m = 0; idx_m < tag_n_Mx; idx_m += 16) { - _mm512_storeu_ps(&C[idx_base_0+idx_m], ZEROVECTOR); - } - - if (tag_n_Mx != M) { - _mm512_mask_storeu_ps(&C[idx_base_0+tag_n_Mx], tail_mask, ZEROVECTOR); - } - idx_base_0 += ldc; - } - } - } else { - } -} \ No newline at end of file +} diff --git a/kernel/x86_64/sbgemm_kernel_16x4_cooperlake.c b/kernel/x86_64/sbgemm_kernel_16x4_cooperlake.c new file mode 100644 index 000000000..7af51b6d8 --- /dev/null +++ b/kernel/x86_64/sbgemm_kernel_16x4_cooperlake.c @@ -0,0 +1,499 @@ +/*************************************************************************** +Copyright (c) 2021, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*****************************************************************************/
+
+#include <immintrin.h>
+#include "common.h"
+
+#define VMOVLDUP(addr, zmm) asm("vmovsldup (%1), %0": "=v"(zmm): "r"(addr))
+#define VMOVHDUP(addr, zmm) asm("vmovshdup (%1), %0": "=v"(zmm): "r"(addr))
+#define BROADCAST64(base, step, n, offset, zmm) \
+ if (n == 0) asm("vbroadcastsd %c2(%1), %0": "=v"(zmm): "r"(base), "n"(offset*2)); \
+ else asm("vbroadcastsd %c4(%1, %2, %c3), %0": "=v"(zmm): "r"(base), "r"(step), "n"(n*2), "n"(offset*2))
+
+#define DECLARE_A_PAIR(A) \
+ __m512i A_lo_##A; __m512i A_hi_##A;
+
+#define LOAD_A_PAIR(A) \
+ VMOVLDUP(ptr_a##A, A_lo_##A); \
+ VMOVHDUP(ptr_a##A, A_hi_##A);
+
+#define MASK_LOAD_A_PAIR(A) { \
+ __m512 tmp = _mm512_maskz_loadu_ps(mmask, ptr_a##A); \
+ A_lo_##A = (__m512i) _mm512_moveldup_ps(tmp); \
+ A_hi_##A = (__m512i) _mm512_movehdup_ps(tmp); \
+}
+
+#define LOAD_A_PAIR_TAIL(A) { \
+ __m256i ymm = _mm256_loadu_si256((void *)ptr_a##A); \
+ __m512 zmm = (__m512) _mm512_cvtepu16_epi32(ymm); \
+ A_lo_##A = (__m512i) _mm512_moveldup_ps(zmm); \
+ A_hi_##A = (__m512i) _mm512_movehdup_ps(zmm); \
+}
+
+#define MASK_LOAD_A_PAIR_TAIL(A) { \
+ __m256i ymm = _mm256_maskz_loadu_epi16(mmask, ptr_a##A); \
+ __m512 zmm = (__m512) _mm512_cvtepu16_epi32(ymm); \
+ A_lo_##A = (__m512i) _mm512_moveldup_ps(zmm); \
+ A_hi_##A = (__m512i) _mm512_movehdup_ps(zmm); \
+}
+
+#define DECLARE_B_PAIR() \
+ __m512i B_lo; __m512i B_hi;
+
+#define PREFETCH_B_STEP 32
+#define PREFETCH_B(Bx, By) \
+ if (By == 0) asm("prefetcht0 %c1(%0)": : "r"(ptr_b##Bx), "n"(PREFETCH_B_STEP * 2)); \
+ else asm("prefetcht0 %c3(%0, %1, %c2)": : "r"(ptr_b##Bx), "r"(n_blksize), "n"(By*2), "n"(PREFETCH_B_STEP * 2))
+
+#define BROADCAST_B_PAIR(Bx, By) \
+ BROADCAST64(ptr_b##Bx, n_blksize, By, 0, B_lo); \
+ BROADCAST64(ptr_b##Bx, n_blksize, By, 4, B_hi);
+
+#define MASK_BROADCAST_B_PAIR(Bx, x) {\
+ __m128 xmm = _mm_maskz_loadu_ps(nmask, ptr_b##Bx); \
+ B_lo = (__m512i) _mm512_broadcastsd_pd((__m128d) xmm); \
+ B_hi = (__m512i) _mm512_broadcastsd_pd(_mm_permute_pd((__m128d) xmm, 0x1)); \
+}
+
+#define BROADCAST_B_PAIR_TAIL(Bx, By) {\
+ __m128i xmm = (__m128i) _mm_load_sd((double *)(ptr_b##Bx + n_blksize * By)); \
+ xmm = _mm_cvtepu16_epi32(xmm); \
+ B_lo = _mm512_broadcast_i32x2(xmm); \
+ B_hi = _mm512_broadcast_i32x2((__m128i) _mm_permute_pd((__m128d) xmm, 0x1)); \
+}
+
+#define MASK_BROADCAST_B_PAIR_TAIL(Bx, By) {\
+ __m128i xmm = _mm_maskz_loadu_epi16(nmask, ptr_b##Bx + n_blksize * By); \
+ xmm = _mm_cvtepu16_epi32(xmm); \
+ B_lo = _mm512_broadcast_i32x2(xmm); \
+ B_hi = _mm512_broadcast_i32x2((__m128i) _mm_permute_pd((__m128d) xmm, 0x1)); \
+}
+
+#define DECLARE_RESULT_4X(A, Bx, By) \
+ __m512 result_00_##A##Bx##By = _mm512_setzero_ps(); \
+ __m512 result_01_##A##Bx##By = _mm512_setzero_ps(); \
+ __m512 result_10_##A##Bx##By = _mm512_setzero_ps(); \
+ __m512 result_11_##A##Bx##By = _mm512_setzero_ps();
+
+#define FMA(a, b, r) r = _mm512_dpbf16_ps(r, (__m512bh)a, (__m512bh)b)
+
+#define MATMUL_4X(A, Bx, By) \
+ FMA(A_lo_##A, B_lo, result_00_##A##Bx##By); \
+ FMA(A_hi_##A, B_lo, result_01_##A##Bx##By); \
+ FMA(A_lo_##A, B_hi, result_10_##A##Bx##By); \
+ FMA(A_hi_##A, B_hi, result_11_##A##Bx##By);
+
+#define _STORE_C_2nx16(addr, val0, val1) \
+ asm("vfmadd213ps (%1), %2, %0": "+v"(val0) : "r"(addr), "v"(alpha_512)); \
+ asm("vfmadd213ps (%1, %3, 4), %2, %0": "+v"(val1) : "r"(addr), "v"(alpha_512), "r"(ldc)); \
+ asm("vmovups %0, (%1)": : "v"(val0), "r"(addr)); \
+ asm("vmovups %0, (%1, %2, 4)": : "v"(val1), "r"(addr), "r"(ldc))
+
+#define
_MASK_STORE_C_2nx16(addr, val0, val1) \ + asm("vfmadd213ps (%1), %2, %0 %{%3%} ": "+v"(val0) : "r"(addr), "v"(alpha_512), "k"(mmask)); \ + asm("vfmadd213ps (%1, %3, 4), %2, %0 %{%4%}": "+v"(val1) : "r"(addr), "v"(alpha_512), "r"(ldc), "k"(mmask)); \ + asm("vmovups %0, (%1) %{%2%}": : "v"(val0), "r"(addr), "k"(mmask)); \ + asm("vmovups %0, (%1, %2, 4) %{%3%}": : "v"(val1), "r"(addr), "r"(ldc), "k"(mmask)) + +#define _REORDER_C_2X(result_0, result_1) { \ + __m512 tmp0, tmp1; \ + tmp0 = _mm512_unpacklo_ps(result_0, result_1); \ + tmp1 = _mm512_unpackhi_ps(result_0, result_1); \ + result_0 = (__m512) _mm512_unpacklo_pd((__m512d) tmp0, (__m512d) tmp1); \ + result_1 = (__m512) _mm512_unpackhi_pd((__m512d) tmp0, (__m512d) tmp1); \ +} + +#define _STORE_2X(ptr_c, result_0, result_1) {\ + _REORDER_C_2X(result_0, result_1) \ + _STORE_C_2nx16(ptr_c, result_0, result_1); \ + ptr_c += ldc * 2; \ +} + +#define _MASK_STORE_2X(ptr_c, result_0, result_1) {\ + _REORDER_C_2X(result_0, result_1) \ + _MASK_STORE_C_2nx16(ptr_c, result_0, result_1); \ + ptr_c += ldc * 2; \ +} + +#define STORE_4X(A, Bx, By) { \ + _STORE_2X(ptr_c##A, result_00_##A##Bx##By, result_01_##A##Bx##By); \ + _STORE_2X(ptr_c##A, result_10_##A##Bx##By, result_11_##A##Bx##By); \ +} + +#define MASK_STORE_4X(A, Bx, By) { \ + _MASK_STORE_2X(ptr_c##A, result_00_##A##Bx##By, result_01_##A##Bx##By); \ + _MASK_STORE_2X(ptr_c##A, result_10_##A##Bx##By, result_11_##A##Bx##By); \ +} + +#define _STORE_C_16(addr, val0) \ + asm("vfmadd213ps (%1), %2, %0": "+v"(val0) : "r"(addr), "v"(alpha_512)); \ + asm("vmovups %0, (%1)": : "v"(val0), "r"(addr)); + +#define _MASK_STORE_C_16(addr, val0) \ + asm("vfmadd213ps (%1), %2, %0 %{%3%} ": "+v"(val0) : "r"(addr), "v"(alpha_512), "k"(mmask)); \ + asm("vmovups %0, (%1) %{%2%}": : "v"(val0), "r"(addr), "k"(mmask)); + +#define N_STORE_4X(A, Bx, By) { \ + _REORDER_C_2X(result_00_##A##Bx##By, result_01_##A##Bx##By); \ + _REORDER_C_2X(result_10_##A##Bx##By, result_11_##A##Bx##By); \ + switch(n_count) { \ + case 3: _STORE_C_16(ptr_c + ldc * 2, result_10_##A##Bx##By); \ + case 2: _STORE_C_16(ptr_c + ldc * 1, result_01_##A##Bx##By); \ + case 1: _STORE_C_16(ptr_c + ldc * 0, result_00_##A##Bx##By); \ + } \ + ptr_c##A += ldc * n_count; \ +} + +#define N_MASK_STORE_4X(A, Bx, By) { \ + _REORDER_C_2X(result_00_##A##Bx##By, result_01_##A##Bx##By); \ + _REORDER_C_2X(result_10_##A##Bx##By, result_11_##A##Bx##By); \ + switch(n_count) { \ + case 3: _MASK_STORE_C_16(ptr_c + ldc * 2, result_10_##A##Bx##By); \ + case 2: _MASK_STORE_C_16(ptr_c + ldc * 1, result_01_##A##Bx##By); \ + case 1: _MASK_STORE_C_16(ptr_c + ldc * 0, result_00_##A##Bx##By); \ + } \ + ptr_c##A += ldc * n_count; \ +} + + +int CNAME (BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT * A, IFLOAT * B, FLOAT * C, BLASLONG ldc) +{ + IFLOAT *ptr_a = A, *ptr_b = B; + IFLOAT *ptr_b0, *ptr_b1; + IFLOAT *ptr_a0, *ptr_a1; + FLOAT *ptr_c = C; + FLOAT *ptr_c0, *ptr_c1; + BLASLONG n_count = n; + BLASLONG m_count, k_count; + BLASLONG n_blksize = 4 * k; + BLASLONG cn_offset = 0; + __m512 alpha_512 = _mm512_broadcastss_ps(_mm_load_ss(&alpha)); + + for (; n_count > 23; n_count -= 24) { + IFLOAT *ptr_b00 = ptr_b; + IFLOAT *ptr_b10 = ptr_b + n_blksize * 3; + ptr_a0 = ptr_a; + ptr_c = C + cn_offset * ldc; + m_count = m; + for (; m_count > 15; m_count -= 16) { + ptr_b0 = ptr_b00; + ptr_b1 = ptr_b10; + DECLARE_A_PAIR(0); + DECLARE_B_PAIR(); + DECLARE_RESULT_4X(0, 0, 0); DECLARE_RESULT_4X(0, 0, 1); DECLARE_RESULT_4X(0, 0, 2); + DECLARE_RESULT_4X(0, 1, 0); DECLARE_RESULT_4X(0, 1, 1); 
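+ // Each MATMUL_4X below expands to _mm512_dpbf16_ps; a scalar model of
+ // that instruction (ignoring internal rounding details), with
+ // illustrative helper names:
+#if 0 /* reference sketch, excluded from the build */
+static float bf16_to_f32(bfloat16 x)
+{
+    union { unsigned int u; float f; } cvt = { (unsigned int)x << 16 };
+    return cvt.f;
+}
+static void dpbf16_ps_ref(float r[16], const bfloat16 a[32], const bfloat16 b[32])
+{
+    /* fp32 lane i accumulates the dot product of one bf16 pair */
+    for (int i = 0; i < 16; i++)
+        r[i] += bf16_to_f32(a[2*i]) * bf16_to_f32(b[2*i])
+              + bf16_to_f32(a[2*i+1]) * bf16_to_f32(b[2*i+1]);
+}
+#endif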
DECLARE_RESULT_4X(0, 1, 2); + k_count = k; + for (; k_count > 3; k_count -=4) { + LOAD_A_PAIR(0); + _mm_prefetch(ptr_a0 + 128, _MM_HINT_T0); + ptr_a0 += 16 * 2; + BROADCAST_B_PAIR(0, 0); PREFETCH_B(0, 0); MATMUL_4X(0, 0, 0); + BROADCAST_B_PAIR(0, 1); PREFETCH_B(0, 1); MATMUL_4X(0, 0, 1); + BROADCAST_B_PAIR(0, 2); PREFETCH_B(0, 2); MATMUL_4X(0, 0, 2); + ptr_b0 += 4 * 2; + BROADCAST_B_PAIR(1, 0); PREFETCH_B(1, 0); MATMUL_4X(0, 1, 0); + BROADCAST_B_PAIR(1, 1); PREFETCH_B(1, 1); MATMUL_4X(0, 1, 1); + BROADCAST_B_PAIR(1, 2); PREFETCH_B(1, 2); MATMUL_4X(0, 1, 2); + ptr_b1 += 4 * 2; + + LOAD_A_PAIR(0); + _mm_prefetch(ptr_a0 + 128, _MM_HINT_T0); + ptr_a0 += 16 * 2; + BROADCAST_B_PAIR(0, 0); MATMUL_4X(0, 0, 0); + BROADCAST_B_PAIR(0, 1); MATMUL_4X(0, 0, 1); + BROADCAST_B_PAIR(0, 2); MATMUL_4X(0, 0, 2); + ptr_b0 += 4 * 2; + BROADCAST_B_PAIR(1, 0); MATMUL_4X(0, 1, 0); + BROADCAST_B_PAIR(1, 1); MATMUL_4X(0, 1, 1); + BROADCAST_B_PAIR(1, 2); MATMUL_4X(0, 1, 2); + ptr_b1 += 4 * 2; + } + for (; k_count > 1; k_count -=2) { + LOAD_A_PAIR(0); + ptr_a0 += 16 * 2; + BROADCAST_B_PAIR(0, 0); MATMUL_4X(0, 0, 0); + BROADCAST_B_PAIR(0, 1); MATMUL_4X(0, 0, 1); + BROADCAST_B_PAIR(0, 2); MATMUL_4X(0, 0, 2); + ptr_b0 += 4 * 2; + BROADCAST_B_PAIR(1, 0); MATMUL_4X(0, 1, 0); + BROADCAST_B_PAIR(1, 1); MATMUL_4X(0, 1, 1); + BROADCAST_B_PAIR(1, 2); MATMUL_4X(0, 1, 2); + ptr_b1 += 4 * 2; + } + if (k_count > 0) { + LOAD_A_PAIR_TAIL(0); + ptr_a0 += 16; + BROADCAST_B_PAIR_TAIL(0, 0); MATMUL_4X(0, 0, 0); + BROADCAST_B_PAIR_TAIL(0, 1); MATMUL_4X(0, 0, 1); + BROADCAST_B_PAIR_TAIL(0, 2); MATMUL_4X(0, 0, 2); + ptr_b0 += 4; + BROADCAST_B_PAIR_TAIL(1, 0); MATMUL_4X(0, 1, 0); + BROADCAST_B_PAIR_TAIL(1, 1); MATMUL_4X(0, 1, 1); + BROADCAST_B_PAIR_TAIL(1, 2); MATMUL_4X(0, 1, 2); + ptr_b1 += 4; + } + ptr_c0 = ptr_c; + STORE_4X(0, 0, 0); STORE_4X(0, 0, 1); STORE_4X(0, 0, 2); + STORE_4X(0, 1, 0); STORE_4X(0, 1, 1); STORE_4X(0, 1, 2); + ptr_c += 16; + } + if (m_count > 0) { + __mmask16 mmask = (1UL << m_count) - 1; + ptr_b0 = ptr_b00; + ptr_b1 = ptr_b10; + DECLARE_A_PAIR(0); + DECLARE_B_PAIR(); + DECLARE_RESULT_4X(0, 0, 0); DECLARE_RESULT_4X(0, 0, 1); DECLARE_RESULT_4X(0, 0, 2); + DECLARE_RESULT_4X(0, 1, 0); DECLARE_RESULT_4X(0, 1, 1); DECLARE_RESULT_4X(0, 1, 2); + for (k_count = k; k_count > 1; k_count -=2) { + MASK_LOAD_A_PAIR(0); + ptr_a0 += m_count * 2; + BROADCAST_B_PAIR(0, 0); MATMUL_4X(0, 0, 0); + BROADCAST_B_PAIR(0, 1); MATMUL_4X(0, 0, 1); + BROADCAST_B_PAIR(0, 2); MATMUL_4X(0, 0, 2); + ptr_b0 += 4 * 2; + BROADCAST_B_PAIR(1, 0); MATMUL_4X(0, 1, 0); + BROADCAST_B_PAIR(1, 1); MATMUL_4X(0, 1, 1); + BROADCAST_B_PAIR(1, 2); MATMUL_4X(0, 1, 2); + ptr_b1 += 4 * 2; + } + if (k_count > 0) { + MASK_LOAD_A_PAIR_TAIL(0); + ptr_a0 += m_count; + BROADCAST_B_PAIR_TAIL(0, 0); MATMUL_4X(0, 0, 0); + BROADCAST_B_PAIR_TAIL(0, 1); MATMUL_4X(0, 0, 1); + BROADCAST_B_PAIR_TAIL(0, 2); MATMUL_4X(0, 0, 2); + ptr_b0 += 4; + BROADCAST_B_PAIR_TAIL(1, 0); MATMUL_4X(0, 1, 0); + BROADCAST_B_PAIR_TAIL(1, 1); MATMUL_4X(0, 1, 1); + BROADCAST_B_PAIR_TAIL(1, 2); MATMUL_4X(0, 1, 2); + ptr_b1 += 4; + } + ptr_c0 = ptr_c; + MASK_STORE_4X(0, 0, 0); MASK_STORE_4X(0, 0, 1); MASK_STORE_4X(0, 0, 2); + MASK_STORE_4X(0, 1, 0); MASK_STORE_4X(0, 1, 1); MASK_STORE_4X(0, 1, 2); + ptr_c += m_count; + } + ptr_b += 24 * k; + cn_offset += 24; + } + for (; n_count > 11; n_count -= 12) { + IFLOAT *ptr_b00 = ptr_b; + ptr_a0 = ptr_a; + ptr_a1 = ptr_a + 16 * k; + ptr_c = C + cn_offset * ldc; + m_count = m; + for (; m_count > 31; m_count -= 32) { + ptr_b0 = ptr_b00; + DECLARE_A_PAIR(0); 
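+ // LOAD_A_PAIR fills A_lo/A_hi via vmovsldup/vmovshdup, which duplicate
+ // the even resp. odd 32-bit elements, so each bf16 pair is replicated
+ // for two adjacent C rows. Scalar model (names illustrative only):
+#if 0 /* reference sketch, excluded from the build */
+static void moveldup_ref(float dst[16], const float src[16])
+{
+    for (int i = 0; i < 16; i += 2) { dst[i] = src[i]; dst[i + 1] = src[i]; }
+}
+static void movehdup_ref(float dst[16], const float src[16])
+{
+    for (int i = 0; i < 16; i += 2) { dst[i] = src[i + 1]; dst[i + 1] = src[i + 1]; }
+}
+#endif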
DECLARE_A_PAIR(1); + DECLARE_B_PAIR(); + DECLARE_RESULT_4X(0, 0, 0); DECLARE_RESULT_4X(0, 0, 1); DECLARE_RESULT_4X(0, 0, 2); + DECLARE_RESULT_4X(1, 0, 0); DECLARE_RESULT_4X(1, 0, 1); DECLARE_RESULT_4X(1, 0, 2); + for (k_count = k; k_count > 1; k_count -=2) { + LOAD_A_PAIR(0); LOAD_A_PAIR(1); + ptr_a0 += 16 * 2; + ptr_a1 += 16 * 2; + BROADCAST_B_PAIR(0, 0); MATMUL_4X(0, 0, 0); MATMUL_4X(1, 0, 0); + BROADCAST_B_PAIR(0, 1); MATMUL_4X(0, 0, 1); MATMUL_4X(1, 0, 1); + BROADCAST_B_PAIR(0, 2); MATMUL_4X(0, 0, 2); MATMUL_4X(1, 0, 2); + ptr_b0 += 4 * 2; + } + if (k_count > 0) { + LOAD_A_PAIR_TAIL(0); LOAD_A_PAIR_TAIL(1); + ptr_a0 += 16; + ptr_a1 += 16; + BROADCAST_B_PAIR_TAIL(0, 0); MATMUL_4X(0, 0, 0); MATMUL_4X(1, 0, 0); + BROADCAST_B_PAIR_TAIL(0, 1); MATMUL_4X(0, 0, 1); MATMUL_4X(1, 0, 1); + BROADCAST_B_PAIR_TAIL(0, 2); MATMUL_4X(0, 0, 2); MATMUL_4X(1, 0, 2); + ptr_b0 += 4; + } + ptr_c0 = ptr_c; + ptr_c1 = ptr_c + 16; + STORE_4X(0, 0, 0); STORE_4X(1, 0, 0); + STORE_4X(0, 0, 1); STORE_4X(1, 0, 1); + STORE_4X(0, 0, 2); STORE_4X(1, 0, 2); + ptr_c += 16 * 2; + ptr_a0 = ptr_a1; + ptr_a1 = ptr_a0 + 16 * k; + } + for (; m_count > 15; m_count -= 16) { + ptr_b0 = ptr_b00; + DECLARE_A_PAIR(0); + DECLARE_B_PAIR(); + DECLARE_RESULT_4X(0, 0, 0); DECLARE_RESULT_4X(0, 0, 1); DECLARE_RESULT_4X(0, 0, 2); + for (k_count = k; k_count > 1; k_count -=2) { + LOAD_A_PAIR(0); + ptr_a0 += 16 * 2; + BROADCAST_B_PAIR(0, 0); MATMUL_4X(0, 0, 0); + BROADCAST_B_PAIR(0, 1); MATMUL_4X(0, 0, 1); + BROADCAST_B_PAIR(0, 2); MATMUL_4X(0, 0, 2); + ptr_b0 += 4 * 2; + } + if (k_count > 0) { + LOAD_A_PAIR_TAIL(0); + ptr_a0 += 16; + BROADCAST_B_PAIR_TAIL(0, 0); MATMUL_4X(0, 0, 0); + BROADCAST_B_PAIR_TAIL(0, 1); MATMUL_4X(0, 0, 1); + BROADCAST_B_PAIR_TAIL(0, 2); MATMUL_4X(0, 0, 2); + ptr_b0 += 4; + } + ptr_c0 = ptr_c; + STORE_4X(0, 0, 0); STORE_4X(0, 0, 1); STORE_4X(0, 0, 2); + ptr_c += 16; + } + if (m_count > 0) { + __mmask16 mmask = (1UL << m_count) - 1; + ptr_b0 = ptr_b00; + DECLARE_A_PAIR(0); + DECLARE_B_PAIR(); + DECLARE_RESULT_4X(0, 0, 0); DECLARE_RESULT_4X(0, 0, 1); DECLARE_RESULT_4X(0, 0, 2); + for (k_count = k; k_count > 1; k_count -=2) { + MASK_LOAD_A_PAIR(0); + ptr_a0 += m_count * 2; + BROADCAST_B_PAIR(0, 0); MATMUL_4X(0, 0, 0); + BROADCAST_B_PAIR(0, 1); MATMUL_4X(0, 0, 1); + BROADCAST_B_PAIR(0, 2); MATMUL_4X(0, 0, 2); + ptr_b0 += 4 * 2; + } + if (k_count > 0) { + MASK_LOAD_A_PAIR_TAIL(0); + ptr_a0 += m_count; + BROADCAST_B_PAIR_TAIL(0, 0); MATMUL_4X(0, 0, 0); + BROADCAST_B_PAIR_TAIL(0, 1); MATMUL_4X(0, 0, 1); + BROADCAST_B_PAIR_TAIL(0, 2); MATMUL_4X(0, 0, 2); + ptr_b0 += 4; + } + ptr_c0 = ptr_c; + MASK_STORE_4X(0, 0, 0); MASK_STORE_4X(0, 0, 1); MASK_STORE_4X(0, 0, 2); + ptr_c += m_count; + } + ptr_b += 12 * k; + cn_offset += 12; + } + for (; n_count > 3; n_count -= 4) { + IFLOAT *ptr_b00 = ptr_b; + ptr_a0 = ptr_a; + ptr_c = C + cn_offset * ldc; + m_count = m; + for (; m_count > 15; m_count -= 16) { + ptr_b0 = ptr_b00; + DECLARE_A_PAIR(0); + DECLARE_B_PAIR(); + DECLARE_RESULT_4X(0, 0, 0); + for (k_count = k; k_count > 1; k_count -=2) { + LOAD_A_PAIR(0); + BROADCAST_B_PAIR(0, 0); MATMUL_4X(0, 0, 0); + ptr_b0 += 4 * 2; + ptr_a0 += 16 * 2; + } + if (k_count > 0) { + LOAD_A_PAIR_TAIL(0); + BROADCAST_B_PAIR_TAIL(0, 0); MATMUL_4X(0, 0, 0); + ptr_b0 += 4; + ptr_a0 += 16; + } + ptr_c0 = ptr_c; + STORE_4X(0, 0, 0); + ptr_c += 16; + } + if (m_count > 0) { + __mmask16 mmask = (1UL << m_count) - 1; + ptr_b0 = ptr_b00; + DECLARE_A_PAIR(0); + DECLARE_B_PAIR(); + DECLARE_RESULT_4X(0, 0, 0); + for (k_count = k; k_count > 1; k_count -=2) { + 
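+ // k advances by 2 because each vdpbf16ps lane consumes a (k, k+1)
+ // bf16 pair; the *_TAIL loaders below widen bf16 to 32-bit lanes for
+ // odd K, so the missing pair partner contributes zero. Loop shape:
+#if 0 /* reference sketch, excluded from the build */
+for (k_count = k; k_count > 1; k_count -= 2) { /* paired-k body */ }
+if (k_count > 0) { /* odd-K tail: second bf16 of the pair is zero */ }
+#endif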
MASK_LOAD_A_PAIR(0); + BROADCAST_B_PAIR(0, 0); MATMUL_4X(0, 0, 0); + ptr_b0 += 4 * 2; + ptr_a0 += m_count * 2; + } + if (k_count > 0) { + MASK_LOAD_A_PAIR_TAIL(0); + BROADCAST_B_PAIR_TAIL(0, 0); MATMUL_4X(0, 0, 0); + ptr_b0 += 4; + ptr_a0 += m_count; + } + ptr_c0 = ptr_c; + MASK_STORE_4X(0, 0, 0); + ptr_c += m_count; + } + ptr_b += 4 * k; + cn_offset += 4; + } + if (n_count > 0) { + __mmask8 nmask = (1UL << n_count) - 1; + IFLOAT *ptr_b00 = ptr_b; + ptr_a0 = ptr_a; + ptr_c = C + cn_offset * ldc; + m_count = m; + for (; m_count > 15; m_count -= 16) { + ptr_b0 = ptr_b00; + DECLARE_A_PAIR(0); + DECLARE_B_PAIR(); + DECLARE_RESULT_4X(0, 0, 0); + for (k_count = k; k_count > 1; k_count -=2) { + LOAD_A_PAIR(0); + MASK_BROADCAST_B_PAIR(0, 0); MATMUL_4X(0, 0, 0); + ptr_b0 += n_count * 2; + ptr_a0 += 16 * 2; + } + if (k_count > 0) { + LOAD_A_PAIR_TAIL(0); + MASK_BROADCAST_B_PAIR_TAIL(0, 0); MATMUL_4X(0, 0, 0); + ptr_b0 += n_count; + ptr_a0 += 16; + } + ptr_c0 = ptr_c; + N_STORE_4X(0, 0, 0); + ptr_c += 16; + } + if (m_count > 0) { + __mmask16 mmask = (1UL << m_count) - 1; + ptr_b0 = ptr_b00; + DECLARE_A_PAIR(0); + DECLARE_B_PAIR(); + DECLARE_RESULT_4X(0, 0, 0); + for (k_count = k; k_count > 1; k_count -=2) { + MASK_LOAD_A_PAIR(0); + MASK_BROADCAST_B_PAIR(0, 0); MATMUL_4X(0, 0, 0); + ptr_b0 += n_count * 2; + ptr_a0 += m_count * 2; + } + if (k_count > 0) { + MASK_LOAD_A_PAIR_TAIL(0); + MASK_BROADCAST_B_PAIR_TAIL(0, 0); MATMUL_4X(0, 0, 0); + ptr_b0 += n_count; + ptr_a0 += m_count; + } + ptr_c0 = ptr_c; + N_MASK_STORE_4X(0, 0, 0); + ptr_c += m_count; + } + } + return 0; +} diff --git a/kernel/x86_64/sbgemm_microk_cooperlake_template.c b/kernel/x86_64/sbgemm_microk_cooperlake_template.c index dd4cb440b..b8ed9838e 100644 --- a/kernel/x86_64/sbgemm_microk_cooperlake_template.c +++ b/kernel/x86_64/sbgemm_microk_cooperlake_template.c @@ -1,46 +1,113 @@ -#include "sbgemm.h" #include "bf16_common_macros.h" #include +#define BF16_BLOCK_STEP_N 8 +#define BF16_BLOCK_THRES_K 1024 +#define BF16_BLOCK_THRES_M 32 +#define BF16_BLOCK_THRES_N 1024 + +#define A(i,j) A[(i)*lda+(j)] +#define B(i,j) B[(i)*ldb+(j)] +#define C(i,j) C[(i)*ldc+(j)] + +#define ONE 1.e0f +#define ZERO 0.e0f + #undef STORE16_COMPLETE_RESULT #undef STORE16_MASK_COMPLETE_RESULT -#undef SBGEMM_BLOCK_KERNEL_32x8x32 -#undef SBGEMM_BLOCK_KERNEL_16x8x32 -#undef SBGEMM_BLOCK_KERNEL_32xNx32 -#undef SBGEMM_BLOCK_KERNEL_16xNx32 -#undef SBGEMM_BLOCKING_KERNEL_2 +#undef SBGEMM_BLOCK_KERNEL_NN_32x8xK +#undef SBGEMM_BLOCK_KERNEL_NN_16x8xK +#undef SBGEMM_BLOCK_KERNEL_NN_32xNx32 +#undef SBGEMM_BLOCK_KERNEL_NN_16xNx32 +#undef SBGEMM_BLOCK_KERNEL_NT_32x8xK +#undef SBGEMM_BLOCK_KERNEL_NT_16x8xK +#undef SBGEMM_BLOCK_KERNEL_NT_32xNxK +#undef SBGEMM_BLOCK_KERNEL_NT_16xNxK +#undef SBGEMM_BLOCK_KERNEL_TN_32x8xK +#undef SBGEMM_BLOCK_KERNEL_TN_16x8xK +#undef SBGEMM_BLOCK_KERNEL_TN_32xNx32 +#undef SBGEMM_BLOCK_KERNEL_TN_16xNx32 +#undef SBGEMM_BLOCK_KERNEL_TT_32x8xK +#undef SBGEMM_BLOCK_KERNEL_TT_16x8xK +#undef SBGEMM_BLOCK_KERNEL_TT_32xNxK +#undef SBGEMM_BLOCK_KERNEL_TT_16xNxK +#undef SBGEMM_BLOCKING_KERNEL_NN +#undef SBGEMM_BLOCKING_KERNEL_NT +#undef SBGEMM_BLOCKING_KERNEL_TN +#undef SBGEMM_BLOCKING_KERNEL_TT #ifndef ONE_ALPHA // ALPHA is not ONE - #define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_ALPHA_ONE - #define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_ALPHA_ONE - #define SBGEMM_BLOCK_KERNEL_32x8x32 sbgemm_block_kernel_32x8x32_alpha - #define SBGEMM_BLOCK_KERNEL_16x8x32 sbgemm_block_kernel_16x8x32_alpha - #define SBGEMM_BLOCK_KERNEL_32xNx32 
sbgemm_block_kernel_32xNx32_alpha - #define SBGEMM_BLOCK_KERNEL_16xNx32 sbgemm_block_kernel_16xNx32_alpha - #define SBGEMM_BLOCKING_KERNEL_2 sbgemm_blocking_kernel_2_alpha + #define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_ALPHA_ONE + #define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_ALPHA_ONE + + #define SBGEMM_BLOCK_KERNEL_NN_32x8xK sbgemm_block_kernel_nn_32x8xK_alpha + #define SBGEMM_BLOCK_KERNEL_NN_16x8xK sbgemm_block_kernel_nn_16x8xK_alpha + #define SBGEMM_BLOCK_KERNEL_NN_32xNx32 sbgemm_block_kernel_nn_32xNx32_alpha + #define SBGEMM_BLOCK_KERNEL_NN_16xNx32 sbgemm_block_kernel_nn_16xNx32_alpha + + #define SBGEMM_BLOCK_KERNEL_NT_32x8xK SBGEMM_BLOCK_KERNEL_NN_32x8xK + #define SBGEMM_BLOCK_KERNEL_NT_16x8xK SBGEMM_BLOCK_KERNEL_NN_16x8xK + #define SBGEMM_BLOCK_KERNEL_NT_32xNxK sbgemm_block_kernel_nt_32xNxK_alpha + #define SBGEMM_BLOCK_KERNEL_NT_16xNxK sbgemm_block_kernel_nt_16xNxK_alpha + + #define SBGEMM_BLOCK_KERNEL_TN_32x8xK sbgemm_block_kernel_tn_32x8xK_alpha + #define SBGEMM_BLOCK_KERNEL_TN_16x8xK sbgemm_block_kernel_tn_16x8xK_alpha + #define SBGEMM_BLOCK_KERNEL_TN_32xNx32 sbgemm_block_kernel_tn_32xNx32_alpha + #define SBGEMM_BLOCK_KERNEL_TN_16xNx32 sbgemm_block_kernel_tn_16xNx32_alpha + + #define SBGEMM_BLOCK_KERNEL_TT_32x8xK SBGEMM_BLOCK_KERNEL_TN_32x8xK + #define SBGEMM_BLOCK_KERNEL_TT_16x8xK SBGEMM_BLOCK_KERNEL_TN_16x8xK + #define SBGEMM_BLOCK_KERNEL_TT_32xNxK sbgemm_block_kernel_tt_32xNxK_alpha + #define SBGEMM_BLOCK_KERNEL_TT_16xNxK sbgemm_block_kernel_tt_16xNxK_alpha + + #define SBGEMM_BLOCKING_KERNEL_NN sbgemm_blocking_kernel_nn_alpha + #define SBGEMM_BLOCKING_KERNEL_NT sbgemm_blocking_kernel_nt_alpha + #define SBGEMM_BLOCKING_KERNEL_TN sbgemm_blocking_kernel_tn_alpha + #define SBGEMM_BLOCKING_KERNEL_TT sbgemm_blocking_kernel_tt_alpha #else // ALPHA is ONE - #define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_ONE_ONE - #define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_ONE_ONE - #define SBGEMM_BLOCK_KERNEL_32x8x32 sbgemm_block_kernel_32x8x32_one - #define SBGEMM_BLOCK_KERNEL_16x8x32 sbgemm_block_kernel_16x8x32_one - #define SBGEMM_BLOCK_KERNEL_32xNx32 sbgemm_block_kernel_32xNx32_one - #define SBGEMM_BLOCK_KERNEL_16xNx32 sbgemm_block_kernel_16xNx32_one - #define SBGEMM_BLOCKING_KERNEL_2 sbgemm_blocking_kernel_2_one + #define STORE16_COMPLETE_RESULT STORE16_COMPLETE_RESULT_ONE_ONE + #define STORE16_MASK_COMPLETE_RESULT STORE16_MASK_COMPLETE_RESULT_ONE_ONE + + #define SBGEMM_BLOCK_KERNEL_NN_32x8xK sbgemm_block_kernel_nn_32x8xK_one + #define SBGEMM_BLOCK_KERNEL_NN_16x8xK sbgemm_block_kernel_nn_16x8xK_one + #define SBGEMM_BLOCK_KERNEL_NN_32xNx32 sbgemm_block_kernel_nn_32xNx32_one + #define SBGEMM_BLOCK_KERNEL_NN_16xNx32 sbgemm_block_kernel_nn_16xNx32_one + + #define SBGEMM_BLOCK_KERNEL_NT_32x8xK SBGEMM_BLOCK_KERNEL_NN_32x8xK + #define SBGEMM_BLOCK_KERNEL_NT_16x8xK SBGEMM_BLOCK_KERNEL_NN_16x8xK + #define SBGEMM_BLOCK_KERNEL_NT_32xNxK sbgemm_block_kernel_nt_32xNxK_one + #define SBGEMM_BLOCK_KERNEL_NT_16xNxK sbgemm_block_kernel_nt_16xNxK_one + + #define SBGEMM_BLOCK_KERNEL_TN_32x8xK sbgemm_block_kernel_tn_32x8xK_one + #define SBGEMM_BLOCK_KERNEL_TN_16x8xK sbgemm_block_kernel_tn_16x8xK_one + #define SBGEMM_BLOCK_KERNEL_TN_32xNx32 sbgemm_block_kernel_tn_32xNx32_one + #define SBGEMM_BLOCK_KERNEL_TN_16xNx32 sbgemm_block_kernel_tn_16xNx32_one + + #define SBGEMM_BLOCK_KERNEL_TT_32x8xK SBGEMM_BLOCK_KERNEL_TN_32x8xK + #define SBGEMM_BLOCK_KERNEL_TT_16x8xK SBGEMM_BLOCK_KERNEL_TN_16x8xK + #define SBGEMM_BLOCK_KERNEL_TT_32xNxK 
sbgemm_block_kernel_tt_32xNxK_one + #define SBGEMM_BLOCK_KERNEL_TT_16xNxK sbgemm_block_kernel_tt_16xNxK_one + + #define SBGEMM_BLOCKING_KERNEL_NN sbgemm_blocking_kernel_nn_one + #define SBGEMM_BLOCKING_KERNEL_NT sbgemm_blocking_kernel_nt_one + #define SBGEMM_BLOCKING_KERNEL_TN sbgemm_blocking_kernel_tn_one + #define SBGEMM_BLOCKING_KERNEL_TT sbgemm_blocking_kernel_tt_one #endif +extern bfloat16 * block_A; +extern bfloat16 * block_B; +/* --------------------------------------------- NN kernels ------------------------------------------ */ // SBGEMM Kernel for 16> (16-m)); - __mmask16 tail_mask = *((__mmask16*) &tail_mask_value); + unsigned short tail_mask = (((unsigned short)0xffff) >> (16-m)); result_512_0 = _mm512_shuffle_f32x4(result_512_0, result_512_0, 0xd8); result_512_1 = _mm512_shuffle_f32x4(result_512_1, result_512_1, 0xd8); result_512_2 = _mm512_shuffle_f32x4(result_512_2, result_512_2, 0xd8); result_512_3 = _mm512_shuffle_f32x4(result_512_3, result_512_3, 0xd8); - STORE16_MASK_COMPLETE_RESULT(result_512_0, (&C[ldc*0]), tail_mask) - STORE16_MASK_COMPLETE_RESULT(result_512_1, (&C[ldc*1]), tail_mask) - STORE16_MASK_COMPLETE_RESULT(result_512_2, (&C[ldc*2]), tail_mask) - STORE16_MASK_COMPLETE_RESULT(result_512_3, (&C[ldc*3]), tail_mask) + STORE16_MASK_COMPLETE_RESULT(result_512_0, (C_addr), tail_mask) + STORE16_MASK_COMPLETE_RESULT(result_512_1, (C_addr + ldc*1), tail_mask) + STORE16_MASK_COMPLETE_RESULT(result_512_2, (C_addr + ldc*2), tail_mask) + STORE16_MASK_COMPLETE_RESULT(result_512_3, (C_addr + ldc*3), tail_mask) result_512_4 = _mm512_shuffle_f32x4(result_512_4, result_512_4, 0xd8); result_512_5 = _mm512_shuffle_f32x4(result_512_5, result_512_5, 0xd8); result_512_6 = _mm512_shuffle_f32x4(result_512_6, result_512_6, 0xd8); result_512_7 = _mm512_shuffle_f32x4(result_512_7, result_512_7, 0xd8); - STORE16_MASK_COMPLETE_RESULT(result_512_4, (&C[ldc*4]), tail_mask) - STORE16_MASK_COMPLETE_RESULT(result_512_5, (&C[ldc*5]), tail_mask) - STORE16_MASK_COMPLETE_RESULT(result_512_6, (&C[ldc*6]), tail_mask) - STORE16_MASK_COMPLETE_RESULT(result_512_7, (&C[ldc*7]), tail_mask) + STORE16_MASK_COMPLETE_RESULT(result_512_4, (C_addr + ldc*4), tail_mask) + STORE16_MASK_COMPLETE_RESULT(result_512_5, (C_addr + ldc*5), tail_mask) + STORE16_MASK_COMPLETE_RESULT(result_512_6, (C_addr + ldc*6), tail_mask) + STORE16_MASK_COMPLETE_RESULT(result_512_7, (C_addr + ldc*7), tail_mask) } else { result_512_0 = _mm512_shuffle_f32x4(result_512_0, result_512_0, 0xd8); result_512_1 = _mm512_shuffle_f32x4(result_512_1, result_512_1, 0xd8); result_512_2 = _mm512_shuffle_f32x4(result_512_2, result_512_2, 0xd8); result_512_3 = _mm512_shuffle_f32x4(result_512_3, result_512_3, 0xd8); - STORE16_COMPLETE_RESULT(result_512_0, (&C[ldc*0])) - STORE16_COMPLETE_RESULT(result_512_1, (&C[ldc*1])) - STORE16_COMPLETE_RESULT(result_512_2, (&C[ldc*2])) - STORE16_COMPLETE_RESULT(result_512_3, (&C[ldc*3])) + STORE16_COMPLETE_RESULT(result_512_0, (C_addr)) + STORE16_COMPLETE_RESULT(result_512_1, (C_addr + ldc*1)) + STORE16_COMPLETE_RESULT(result_512_2, (C_addr + ldc*2)) + STORE16_COMPLETE_RESULT(result_512_3, (C_addr + ldc*3)) result_512_4 = _mm512_shuffle_f32x4(result_512_4, result_512_4, 0xd8); result_512_5 = _mm512_shuffle_f32x4(result_512_5, result_512_5, 0xd8); result_512_6 = _mm512_shuffle_f32x4(result_512_6, result_512_6, 0xd8); result_512_7 = _mm512_shuffle_f32x4(result_512_7, result_512_7, 0xd8); - STORE16_COMPLETE_RESULT(result_512_4, (&C[ldc*4])) - STORE16_COMPLETE_RESULT(result_512_5, (&C[ldc*5])) - 
STORE16_COMPLETE_RESULT(result_512_6, (&C[ldc*6])) - STORE16_COMPLETE_RESULT(result_512_7, (&C[ldc*7])) + STORE16_COMPLETE_RESULT(result_512_4, (C_addr + ldc*4)) + STORE16_COMPLETE_RESULT(result_512_5, (C_addr + ldc*5)) + STORE16_COMPLETE_RESULT(result_512_6, (C_addr + ldc*6)) + STORE16_COMPLETE_RESULT(result_512_7, (C_addr + ldc*7)) } } // SBGEMM Kernel for 16> (32-m)); - __mmask16 tail_mask = *((__mmask16*) &tail_mask_value); + unsigned short tail_mask = (((unsigned short)0xffff) >> (32-m)); for (int i = 0; i < n; i++) { result_512_tmp_0 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base0, result_512[i+8]); result_512_tmp_1 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base1, result_512[i+8]); - STORE16_COMPLETE_RESULT(result_512_tmp_0, (&C[ldc*i])) - STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (&C[ldc*i+16]), tail_mask) + STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*i)) + STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*i + 16), tail_mask) } } else { for (int i = 0; i < n; i++) { result_512_tmp_0 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base0, result_512[i+8]); result_512_tmp_1 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base1, result_512[i+8]); - STORE16_COMPLETE_RESULT(result_512_tmp_0, (&C[ldc*i])) - STORE16_COMPLETE_RESULT(result_512_tmp_1, (&C[ldc*i+16])) + STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*i)) + STORE16_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*i + 16)) } } } // SBGEMM Kernel for 16<=M, N<8, K can be any number, but the processing will take 32 as a base #ifndef ONE_ALPHA // ALPHA is not ONE -void sbgemm_block_kernel_16xNx32_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc) +void sbgemm_block_kernel_nn_16xNx32_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc) #else // ALPHA is ONE -void sbgemm_block_kernel_16xNx32_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc) +void sbgemm_block_kernel_nn_16xNx32_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc) #endif { + bfloat16 * A_addr = A; + bfloat16 * B_addr = B; + float * C_addr = C; + int SHUFFLE_MAGIC_NO = 0x39; BLASLONG tag_k_32x = k & (~31); - BLASLONG idxB_base = 0; - BLASLONG width = 32; #ifndef ONE_ALPHA __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); @@ -432,19 +481,18 @@ void sbgemm_block_kernel_16xNx32_one(BLASLONG m, BLASLONG n, BLASLONG k, float a result_512[i+1] = _mm512_setzero_ps(); } - for (BLASLONG idx_k = 0; idx_k < k; idx_k += 32) { + for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) { // Load B with unroll n - for (int i = 0; i < n; i ++) { - arrayB_512[i] = _mm512_loadu_si512(&B[idxB_base]); - idxB_base += 32; + for (int i = 0; i < n; i++) { + arrayB_512[i] = _mm512_loadu_si512(B_addr); + B_addr += 32; } - if (idx_k == tag_k_32x) {width = k - tag_k_32x;} - - for (BLASLONG idx = 0; idx < width;) { + for (BLASLONG idx = 0; idx < 32;) { // Each two rows are a group for 32-pair bf16 elements // Load two rows into a 512 register - arrayA_512 = _mm512_loadu_si512(&A[idx<<4]); + arrayA_512 = _mm512_loadu_si512(A_addr); + A_addr += 32; for (int i = 0; i < n; i ++) { result_512[i] = _mm512_dpbf16_ps(result_512[i], (__m512bh) arrayA_512, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i]))); @@ -461,24 +509,54 @@ void sbgemm_block_kernel_16xNx32_one(BLASLONG m, BLASLONG n, BLASLONG k, float a } } + if (tag_k_32x != k) { + // 
Load B with unroll n + for (int i = 0; i < n; i++) { + arrayB_512[i] = _mm512_loadu_si512(B_addr); + B_addr += 32; + } + + BLASLONG width = k - tag_k_32x; + for (BLASLONG idx = 0; idx < width;) { + // Each two rows are a group for 32-pair bf16 elements + // Load two rows into a 512 register + arrayA_512 = _mm512_loadu_si512(A_addr); + A_addr += 32; + + for (int i = 0; i < n; i++) { + result_512[i] = _mm512_dpbf16_ps(result_512[i], (__m512bh) arrayA_512, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i]))); + arrayB_512[i] = _mm512_shuffle_epi32(arrayB_512[i], SHUFFLE_MAGIC_NO); + } + + idx += 2; + // Every 4 loops we need to switch to next 128 bits of arrayB registers + if ((idx & (~7)) == idx) { + for (int i = 0; i < n; i++) { + arrayB_512[i] = _mm512_shuffle_i32x4(arrayB_512[i], arrayB_512[i], SHUFFLE_MAGIC_NO); + } + } + } + } + if (m != 16) { - unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-m)); - __mmask16 tail_mask = *((__mmask16*) &tail_mask_value); + unsigned short tail_mask = (((unsigned short)0xffff) >> (16-m)); for (int i = 0; i < n; i++) { result_512[i] = _mm512_shuffle_f32x4(result_512[i], result_512[i], 0xd8); - STORE16_MASK_COMPLETE_RESULT(result_512[i], (&C[ldc*i]), tail_mask) + STORE16_MASK_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i), tail_mask) } } else { for (int i = 0; i < n; i++) { result_512[i] = _mm512_shuffle_f32x4(result_512[i], result_512[i], 0xd8); - STORE16_COMPLETE_RESULT(result_512[i], (&C[ldc*i])) + STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i)) } } } + + #ifndef ONE_ALPHA // ALPHA is not ONE -void sbgemm_blocking_kernel_2_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B) +void sbgemm_blocking_kernel_nn_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B) #else // ALPHA is ONE -void sbgemm_blocking_kernel_2_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B) +void sbgemm_blocking_kernel_nn_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B) #endif { BLASLONG m_step, n_step, k_step, k_step_round32; @@ -499,63 +577,52 @@ void sbgemm_blocking_kernel_2_one(blasint M, blasint N, blasint K, float alpha, while (n_from < N) { for (BLASLONG idx_k = 0; idx_k < K;) { // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ... - COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, &A(idx_k, 0), lda, block_A); - // TODO: MT + COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, 32, &A(idx_k, 0), lda, block_A); for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ... 
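+ // Each packed column of block_B occupies k_step_round32 elements
+ // because the copy kernels zero-pad K up to a multiple of 32; the
+ // round-up computed at the bottom of this loop is equivalent to the
+ // sketch below (function name illustrative only):
+#if 0 /* reference sketch, excluded from the build */
+static BLASLONG round_up_32(BLASLONG x) { return (x + 31) & ~(BLASLONG)31; }
+#endif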
COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32); - SBGEMM_BLOCK_KERNEL_32x8x32(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc); + SBGEMM_BLOCK_KERNEL_NN_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc); } if (tag_n_Nx != n_to) { n_step = n_to - tag_n_Nx; COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32); - SBGEMM_BLOCK_KERNEL_32xNx32(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc); + SBGEMM_BLOCK_KERNEL_NN_32xNx32(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc); } for (BLASLONG idx_m = BF16_BLOCK_THRES_M; idx_m < tag_m_Nx; idx_m += BF16_BLOCK_THRES_M) { - COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, &A(idx_k, idx_m), lda, block_A); + COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, 32, &A(idx_k, idx_m), lda, block_A); for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { - SBGEMM_BLOCK_KERNEL_32x8x32(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, idx_m), ldc); + SBGEMM_BLOCK_KERNEL_NN_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, idx_m), ldc); } if (tag_n_Nx != n_to) { n_step = n_to - tag_n_Nx; - SBGEMM_BLOCK_KERNEL_32xNx32(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, idx_m), ldc); + SBGEMM_BLOCK_KERNEL_NN_32xNx32(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, idx_m), ldc); } } if (tag_m_Nx != M) { m_step = M - tag_m_Nx; if (m_step > 16) { - COL_MAJOR_INCOPY_KERNEL_Kx32m(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A); + COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A); for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { - SBGEMM_BLOCK_KERNEL_32x8x32(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc); + SBGEMM_BLOCK_KERNEL_NN_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc); } if (tag_n_Nx != n_to) { n_step = n_to - tag_n_Nx; - SBGEMM_BLOCK_KERNEL_32xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc); - } - } else if (m_step == 16) { - COL_MAJOR_INCOPY_KERNEL_Kx16(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A); - for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { - SBGEMM_BLOCK_KERNEL_16x8x32(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc); - } - - if (tag_n_Nx != n_to) { - n_step = n_to - tag_n_Nx; - SBGEMM_BLOCK_KERNEL_16xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc); + SBGEMM_BLOCK_KERNEL_NN_32xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc); } } else { - COL_MAJOR_INCOPY_KERNEL_Kx16m(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A); + COL_MAJOR_INCOPY_KERNEL_Kx16(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A); for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { - SBGEMM_BLOCK_KERNEL_16x8x32(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc); + SBGEMM_BLOCK_KERNEL_NN_16x8xK(m_step, 
k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc); } if (tag_n_Nx != n_to) { n_step = n_to - tag_n_Nx; - SBGEMM_BLOCK_KERNEL_16xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc); + SBGEMM_BLOCK_KERNEL_NN_16xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc); } } } @@ -573,22 +640,274 @@ void sbgemm_blocking_kernel_2_one(blasint M, blasint N, blasint K, float alpha, tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1)); } } else { - m_step = M - tag_m_Nx; + m_step = M; + if (m_step > 16) { + while (n_from < N) { + for (BLASLONG idx_k = 0; idx_k < K;) { + // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ... + COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, m_step, &A(idx_k, 0), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ... + COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_NN_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_NN_32xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc); + } + + idx_k += k_step; + k_step = K - idx_k; + k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step; + k_step_round32 = k_step & (~31); + k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32; + } + n_from = n_to; + n_to += BF16_BLOCK_THRES_N; + n_to = (n_to > N) ? N : n_to; + tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1)); + } + } else { + while (n_from < N) { + for (BLASLONG idx_k = 0; idx_k < K;) { + COL_MAJOR_INCOPY_KERNEL_Kx16(k_step, m_step, &A(idx_k, 0), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ... + COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_NN_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_NN_16xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc); + } + + idx_k += k_step; + k_step = K - idx_k; + k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step; + k_step_round32 = k_step & (~31); + k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32; + } + n_from = n_to; + n_to += BF16_BLOCK_THRES_N; + n_to = (n_to > N) ? 
N : n_to; + tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1)); + } + } + } +} +/* ----------------------------------------- End of NN kernels --------------------------------------- */ + +/* --------------------------------------------- NT kernels ------------------------------------------ */ +// SBGEMM Kernel for 16> (32-m)); + for (int i = 0; i < n; i ++) { + result_512_tmp_0 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base0, result_512[i+8]); + result_512_tmp_1 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base1, result_512[i+8]); + STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*i)) + STORE16_MASK_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*i + 16), tail_mask) + } + } else { + for (int i = 0; i < n; i ++) { + result_512_tmp_0 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base0, result_512[i+8]); + result_512_tmp_1 = _mm512_permutex2var_ps(result_512[i], shuffle_idx_base1, result_512[i+8]); + STORE16_COMPLETE_RESULT(result_512_tmp_0, (C_addr + ldc*i)) + STORE16_COMPLETE_RESULT(result_512_tmp_1, (C_addr + ldc*i + 16)) + } + } +} + +// SBGEMM Kernel for M<=16, N<8, K can be any number +#ifndef ONE_ALPHA // ALPHA is not ONE +void sbgemm_block_kernel_nt_16xNxK_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc) +#else // ALPHA is ONE +void sbgemm_block_kernel_nt_16xNxK_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc) +#endif +{ + bfloat16 * A_addr = A; + bfloat16 * B_addr = B; + float * C_addr = C; + +#ifndef ONE_ALPHA + __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); +#endif + + __m512i arrayA_512_0; + __m512i arrayB_512[8]; + __m512 result_512[8]; + + result_512[0] = _mm512_setzero_ps(); + result_512[1] = _mm512_setzero_ps(); + result_512[2] = _mm512_setzero_ps(); + result_512[3] = _mm512_setzero_ps(); + result_512[4] = _mm512_setzero_ps(); + result_512[5] = _mm512_setzero_ps(); + result_512[6] = _mm512_setzero_ps(); + result_512[7] = _mm512_setzero_ps(); + + for (BLASLONG idx_k = 0; idx_k < k; idx_k += 2) { + // Each two rows are a group for 16-pair bf16 elements + // Load two rows into a 512 register + arrayA_512_0 = _mm512_loadu_si512(A_addr); + A_addr += 32; + + for (int i = 0; i < n; i ++) { + _MM512_BROADCASTD_EPI32(B_addr + i*2, arrayB_512[i]); + } + B_addr += 16; + + for (int i = 0; i < n; i ++) { + result_512[i] = _mm512_dpbf16_ps(result_512[i], (__m512bh) arrayA_512_0, (__m512bh) arrayB_512[i]); + } + } + + if (m != 16) { + unsigned short tail_mask = (((unsigned short)0xffff) >> (16-m)); + for (int i = 0; i < n; i++) { + result_512[i] = _mm512_shuffle_f32x4(result_512[i], result_512[i], 0xd8); + STORE16_MASK_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i), tail_mask) + } + } else { + for (int i = 0; i < n; i++) { + result_512[i] = _mm512_shuffle_f32x4(result_512[i], result_512[i], 0xd8); + STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i)) + } + } +} + +#ifndef ONE_ALPHA // ALPHA is not ONE +void sbgemm_blocking_kernel_nt_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B) +#else // ALPHA is ONE +void sbgemm_blocking_kernel_nt_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B) +#endif +{ + BLASLONG m_step, n_step, k_step, k_step_round32; + BLASLONG tag_m_Nx = M & (~(BF16_BLOCK_THRES_M-1)); + + BLASLONG n_from, n_to; + BLASLONG 
tag_n_Nx; + + n_from = 0; + n_to = (BF16_BLOCK_THRES_N > N) ? N : BF16_BLOCK_THRES_N; + tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1)); + + k_step = (K > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : K; + k_step_round32 = k_step & (~31); + k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32; + + if (M >= BF16_BLOCK_THRES_M) { while (n_from < N) { for (BLASLONG idx_k = 0; idx_k < K;) { // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ... - COL_MAJOR_INCOPY_KERNEL_Kx32m(k_step, m_step, &A(idx_k, 0), lda, block_A); - // TODO: MT + COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, 32, &A(idx_k, 0), lda, block_A); for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ... - COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32); - SBGEMM_BLOCK_KERNEL_32x8x32(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc); + COL_MAJOR_OTCOPY_KERNEL_Kx8(k_step, &B(idx_k, idx_n), ldb, block_B + (idx_n-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_NT_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc); } if (tag_n_Nx != n_to) { n_step = n_to - tag_n_Nx; - COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32); - SBGEMM_BLOCK_KERNEL_32xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc); + COL_MAJOR_OTCOPY_KERNEL_Kx8m(k_step, n_step, &B(idx_k, tag_n_Nx), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_NT_32xNxK(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc); + } + + for (BLASLONG idx_m = BF16_BLOCK_THRES_M; idx_m < tag_m_Nx; idx_m += BF16_BLOCK_THRES_M) { + COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, 32, &A(idx_k, idx_m), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + SBGEMM_BLOCK_KERNEL_NT_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, idx_m), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + SBGEMM_BLOCK_KERNEL_NT_32xNxK(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, idx_m), ldc); + } + } + + if (tag_m_Nx != M) { + m_step = M - tag_m_Nx; + if (m_step > 16) { + COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + SBGEMM_BLOCK_KERNEL_NT_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + SBGEMM_BLOCK_KERNEL_NT_32xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc); + } + } else { + COL_MAJOR_INCOPY_KERNEL_Kx16(k_step, m_step, &A(idx_k, tag_m_Nx), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + SBGEMM_BLOCK_KERNEL_NT_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + SBGEMM_BLOCK_KERNEL_NT_16xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc); + } + } } idx_k += k_step; @@ 
-597,14 +916,886 @@ void sbgemm_blocking_kernel_2_one(blasint M, blasint N, blasint K, float alpha, k_step_round32 = k_step & (~31); k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32; } + n_from = n_to; n_to += BF16_BLOCK_THRES_N; n_to = (n_to > N) ? N : n_to; tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1)); } + } else { + m_step = M; + if (m_step > 16) { + while (n_from < N) { + for (BLASLONG idx_k = 0; idx_k < K;) { + // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ... + COL_MAJOR_INCOPY_KERNEL_Kx32(k_step, m_step, &A(idx_k, 0), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ... + COL_MAJOR_OTCOPY_KERNEL_Kx8(k_step, &B(idx_k, idx_n), ldb, block_B + (idx_n-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_NT_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + COL_MAJOR_OTCOPY_KERNEL_Kx8m(k_step, n_step, &B(idx_k, tag_n_Nx), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_NT_32xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc); + } + + idx_k += k_step; + k_step = K - idx_k; + k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step; + k_step_round32 = k_step & (~31); + k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32; + } + n_from = n_to; + n_to += BF16_BLOCK_THRES_N; + n_to = (n_to > N) ? N : n_to; + tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1)); + } + } else { + while (n_from < N) { + for (BLASLONG idx_k = 0; idx_k < K;) { + // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ... + COL_MAJOR_INCOPY_KERNEL_Kx16(k_step, m_step, &A(idx_k, 0), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ... + COL_MAJOR_OTCOPY_KERNEL_Kx8(k_step, &B(idx_k, idx_n), ldb, block_B + (idx_n-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_NT_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + COL_MAJOR_OTCOPY_KERNEL_Kx8m(k_step, n_step, &B(idx_k, tag_n_Nx), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_NT_16xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc); + } + + idx_k += k_step; + k_step = K - idx_k; + k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step; + k_step_round32 = k_step & (~31); + k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32; + } + n_from = n_to; + n_to += BF16_BLOCK_THRES_N; + n_to = (n_to > N) ? 
N : n_to; + tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1)); + } + } + } +} +/* ----------------------------------------- End of NT kernels --------------------------------------- */ + +/* --------------------------------------------- TN kernels ------------------------------------------ */ +// SBGEMM Kernel for 16> (32-m)); + __mmask16 tail_mask = *((__mmask16*) &tail_mask_value); + STORE16_COMPLETE_RESULT(result_512_0, (C_addr)) + STORE16_MASK_COMPLETE_RESULT(result_512_8, (C_addr + 16), tail_mask) + STORE16_COMPLETE_RESULT(result_512_1, (C_addr + ldc)) + STORE16_MASK_COMPLETE_RESULT(result_512_9, (C_addr + ldc + 16), tail_mask) + STORE16_COMPLETE_RESULT(result_512_2, (C_addr + ldc*2)) + STORE16_MASK_COMPLETE_RESULT(result_512_10, (C_addr + ldc*2 + 16), tail_mask) + STORE16_COMPLETE_RESULT(result_512_3, (C_addr + ldc*3)) + STORE16_MASK_COMPLETE_RESULT(result_512_11, (C_addr + ldc*3 + 16), tail_mask) + STORE16_COMPLETE_RESULT(result_512_4, (C_addr + ldc*4)) + STORE16_MASK_COMPLETE_RESULT(result_512_12, (C_addr + ldc*4 + 16), tail_mask) + STORE16_COMPLETE_RESULT(result_512_5, (C_addr + ldc*5)) + STORE16_MASK_COMPLETE_RESULT(result_512_13, (C_addr + ldc*5 + 16), tail_mask) + STORE16_COMPLETE_RESULT(result_512_6, (C_addr + ldc*6)) + STORE16_MASK_COMPLETE_RESULT(result_512_14, (C_addr + ldc*6 + 16), tail_mask) + STORE16_COMPLETE_RESULT(result_512_7, (C_addr + ldc*7)) + STORE16_MASK_COMPLETE_RESULT(result_512_15, (C_addr + ldc*7 + 16), tail_mask) + } else { + STORE16_COMPLETE_RESULT(result_512_0, (C_addr)) + STORE16_COMPLETE_RESULT(result_512_8, (C_addr + 16)) + STORE16_COMPLETE_RESULT(result_512_1, (C_addr + ldc)) + STORE16_COMPLETE_RESULT(result_512_9, (C_addr + ldc + 16)) + STORE16_COMPLETE_RESULT(result_512_2, (C_addr + ldc*2)) + STORE16_COMPLETE_RESULT(result_512_10, (C_addr + ldc*2 + 16)) + STORE16_COMPLETE_RESULT(result_512_3, (C_addr + ldc*3)) + STORE16_COMPLETE_RESULT(result_512_11, (C_addr + ldc*3 + 16)) + STORE16_COMPLETE_RESULT(result_512_4, (C_addr + ldc*4)) + STORE16_COMPLETE_RESULT(result_512_12, (C_addr + ldc*4 + 16)) + STORE16_COMPLETE_RESULT(result_512_5, (C_addr + ldc*5)) + STORE16_COMPLETE_RESULT(result_512_13, (C_addr + ldc*5 + 16)) + STORE16_COMPLETE_RESULT(result_512_6, (C_addr + ldc*6)) + STORE16_COMPLETE_RESULT(result_512_14, (C_addr + ldc*6 + 16)) + STORE16_COMPLETE_RESULT(result_512_7, (C_addr + ldc*7)) + STORE16_COMPLETE_RESULT(result_512_15, (C_addr + ldc*7 + 16)) } } +// SBGEMM Kernel for M=16, N=8, K=Any number +#ifndef ONE_ALPHA // ALPHA is not ONE +void sbgemm_block_kernel_tn_16x8xK_alpha(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc) +#else // ALPHA is ONE +void sbgemm_block_kernel_tn_16x8xK_one(BLASLONG m, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc) +#endif +{ + bfloat16 * A_addr = A; + bfloat16 * B_addr = B; + float * C_addr = C; + +#ifndef ONE_ALPHA + __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); +#endif + + __m512i arrayA_512_0; + __m512i arrayB_512_0, arrayB_512_1, arrayB_512_2, arrayB_512_3, arrayB_512_4, arrayB_512_5, arrayB_512_6, arrayB_512_7; + __m512 result_512_0, result_512_1, result_512_2, result_512_3, result_512_4, result_512_5, result_512_6, result_512_7; + + result_512_0 = _mm512_setzero_ps(); + result_512_1 = _mm512_setzero_ps(); + result_512_2 = _mm512_setzero_ps(); + result_512_3 = _mm512_setzero_ps(); + result_512_4 = _mm512_setzero_ps(); + result_512_5 = _mm512_setzero_ps(); + result_512_6 = _mm512_setzero_ps(); + result_512_7 = _mm512_setzero_ps(); + + for (BLASLONG idx_k = 0; idx_k < 
k; idx_k += 2) { + // Load 16 pair of BF16 elements from A (16 rows) + arrayA_512_0 = _mm512_loadu_si512(A_addr + 0); + + // Load 8 rows of B + _MM512_BROADCASTD_EPI32(B_addr + 0, arrayB_512_0); + _MM512_BROADCASTD_EPI32(B_addr + 2, arrayB_512_1); + _MM512_BROADCASTD_EPI32(B_addr + 4, arrayB_512_2); + _MM512_BROADCASTD_EPI32(B_addr + 6, arrayB_512_3); + _MM512_BROADCASTD_EPI32(B_addr + 8, arrayB_512_4); + _MM512_BROADCASTD_EPI32(B_addr + 10, arrayB_512_5); + _MM512_BROADCASTD_EPI32(B_addr + 12, arrayB_512_6); + _MM512_BROADCASTD_EPI32(B_addr + 14, arrayB_512_7); + + result_512_0 = _mm512_dpbf16_ps(result_512_0, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_0); + result_512_1 = _mm512_dpbf16_ps(result_512_1, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_1); + result_512_2 = _mm512_dpbf16_ps(result_512_2, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_2); + result_512_3 = _mm512_dpbf16_ps(result_512_3, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_3); + result_512_4 = _mm512_dpbf16_ps(result_512_4, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_4); + result_512_5 = _mm512_dpbf16_ps(result_512_5, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_5); + result_512_6 = _mm512_dpbf16_ps(result_512_6, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_6); + result_512_7 = _mm512_dpbf16_ps(result_512_7, (__m512bh) arrayA_512_0, (__m512bh) arrayB_512_7); + + // Load B with unroll 8 + B_addr += 16; + // Load A with unroll 32 + A_addr += 32; + } + + if (m != 16) { + unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-m)); + __mmask16 tail_mask = *((__mmask16*) &tail_mask_value); + STORE16_MASK_COMPLETE_RESULT(result_512_0, (C_addr), tail_mask) + STORE16_MASK_COMPLETE_RESULT(result_512_1, (C_addr + ldc), tail_mask) + STORE16_MASK_COMPLETE_RESULT(result_512_2, (C_addr + ldc*2), tail_mask) + STORE16_MASK_COMPLETE_RESULT(result_512_3, (C_addr + ldc*3), tail_mask) + STORE16_MASK_COMPLETE_RESULT(result_512_4, (C_addr + ldc*4), tail_mask) + STORE16_MASK_COMPLETE_RESULT(result_512_5, (C_addr + ldc*5), tail_mask) + STORE16_MASK_COMPLETE_RESULT(result_512_6, (C_addr + ldc*6), tail_mask) + STORE16_MASK_COMPLETE_RESULT(result_512_7, (C_addr + ldc*7), tail_mask) + } else { + STORE16_COMPLETE_RESULT(result_512_0, (C_addr)) + STORE16_COMPLETE_RESULT(result_512_1, (C_addr + ldc)) + STORE16_COMPLETE_RESULT(result_512_2, (C_addr + ldc*2)) + STORE16_COMPLETE_RESULT(result_512_3, (C_addr + ldc*3)) + STORE16_COMPLETE_RESULT(result_512_4, (C_addr + ldc*4)) + STORE16_COMPLETE_RESULT(result_512_5, (C_addr + ldc*5)) + STORE16_COMPLETE_RESULT(result_512_6, (C_addr + ldc*6)) + STORE16_COMPLETE_RESULT(result_512_7, (C_addr + ldc*7)) + } +} + +// SBGEMM Kernel for 16> (32-m)); + for (int i = 0; i < n; i++) { + STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i)) + STORE16_MASK_COMPLETE_RESULT(result_512[i+8], (C_addr + ldc*i + 16), tail_mask) + } + } else { + for (int i = 0; i < n; i++) { + STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i)) + STORE16_COMPLETE_RESULT(result_512[i+8], (C_addr + ldc*i + 16)) + } + } +} + +// SBGEMM Kernel for M<=16, N<8, K=Any number but will be processed based on 32 +#ifndef ONE_ALPHA // ALPHA is not ONE +void sbgemm_block_kernel_tn_16xNx32_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc) +#else // ALPHA is ONE +void sbgemm_block_kernel_tn_16xNx32_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc) +#endif +{ + bfloat16 * A_addr = A; + bfloat16 * B_addr = B; + float * C_addr = C; + + 
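/*
 * _MM512_BROADCASTD_EPI32(addr, reg) above fills a zmm register with one
 * bf16 pair (rows k and k+1 of a single B column), so one dpbf16 instruction
 * accumulates a full 16-row column of C.  The macro is defined elsewhere in
 * this template and its exact expansion is not shown here; the net effect is
 * a dword broadcast from memory, roughly equivalent to this sketch
 * (hypothetical helper name, assumption only):
 */
#if 0
#include <immintrin.h>
#include <stdint.h>
#include <string.h>

static __m512i broadcast_bf16_pair(const void *p)
{
    int32_t pair;
    memcpy(&pair, p, sizeof(pair));     /* two adjacent bf16 values       */
    return _mm512_set1_epi32(pair);     /* replicated into all 16 dwords  */
}
#endif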
int SHUFFLE_MAGIC_NO = 0x39; + BLASLONG tag_k_32x = k & (~31); + +#ifndef ONE_ALPHA + __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); +#endif + + __m512i arrayA_512; + __m512i arrayB_512[8]; + __m512 result_512[8]; + + for (int i = 0; i < 8; i++) { + result_512[i] = _mm512_setzero_ps(); + } + + for (BLASLONG idx_k = 0; idx_k < tag_k_32x; idx_k += 32) { + // Load B with unroll n + for (int i = 0; i < n; i ++) { + arrayB_512[i] = _mm512_loadu_si512(B_addr); + B_addr += 32; + } + + for (BLASLONG idx = 0; idx < 32;) { + // Each two rows are a group for 32-pair bf16 elements + arrayA_512 = _mm512_loadu_si512(A_addr); + A_addr += 32; + + for (int i = 0; i < n; i++) { + result_512[i] = _mm512_dpbf16_ps(result_512[i], (__m512bh) arrayA_512, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i]))); + arrayB_512[i] = _mm512_shuffle_epi32(arrayB_512[i], SHUFFLE_MAGIC_NO); + } + + idx += 2; + // Every 4 loops we need to switch to next 128 bits of arrayB registers + if ((idx & (~7)) == idx) { + for (int i = 0; i < n; i++) { + arrayB_512[i] = _mm512_shuffle_i32x4(arrayB_512[i], arrayB_512[i], SHUFFLE_MAGIC_NO); + } + } + } + } + + if (tag_k_32x != k) { + // Load B with unroll n + for (int i = 0; i < n; i ++) { + arrayB_512[i] = _mm512_loadu_si512(B_addr); + B_addr += 32; + } + + BLASLONG width = k - tag_k_32x; + for (BLASLONG idx = 0; idx < width;) { + // Each two rows are a group for 32-pair bf16 elements + arrayA_512 = _mm512_loadu_si512(A_addr); + A_addr += 32; + + for (int i = 0; i < n; i++) { + result_512[i] = _mm512_dpbf16_ps(result_512[i], (__m512bh) arrayA_512, (__m512bh) _mm512_broadcastd_epi32(_mm512_castsi512_si128(arrayB_512[i]))); + arrayB_512[i] = _mm512_shuffle_epi32(arrayB_512[i], SHUFFLE_MAGIC_NO); + } + + idx += 2; + // Every 4 loops we need to switch to next 128 bits of arrayB registers + if ((idx & (~7)) == idx) { + for (int i = 0; i < n; i++) { + arrayB_512[i] = _mm512_shuffle_i32x4(arrayB_512[i], arrayB_512[i], SHUFFLE_MAGIC_NO); + } + } + } + } + + if (m != 16) { + unsigned short tail_mask = (((unsigned short)0xffff) >> (16-m)); + for (int i = 0; i < n; i++) { + STORE16_MASK_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i), tail_mask) + } + } else { + for (int i = 0; i < n; i++) { + STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i)) + } + } +} + +#ifndef ONE_ALPHA // ALPHA is not ONE +void sbgemm_blocking_kernel_tn_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B) +#else // ALPHA is ONE +void sbgemm_blocking_kernel_tn_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B) +#endif +{ + BLASLONG m_step, n_step, k_step, k_step_round32; + BLASLONG tag_m_Nx = M & (~(BF16_BLOCK_THRES_M-1)); + + BLASLONG n_from, n_to; + BLASLONG tag_n_Nx; + + n_from = 0; + n_to = (BF16_BLOCK_THRES_N > N) ? N : BF16_BLOCK_THRES_N; + tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1)); + + k_step = (K > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : K; + k_step_round32 = k_step & (~31); + k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32; + + if (M >= BF16_BLOCK_THRES_M) { + while (n_from < N) { + for (BLASLONG idx_k = 0; idx_k < K;) { + // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ... 
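/*
 * The k_step_round32 computation above ("mask down to a multiple of 32, then
 * add 32 if anything was cut off") is a round-up: packed-B panels are padded
 * so every column occupies a whole number of 32-element bf16 groups.  For
 * reference, the equivalent closed form (sketch, not part of the patch;
 * identical result for k >= 0):
 */
#if 0
static long round_up_to_32(long k)
{
    return (k + 31) & ~31L;   /* same as the two-step mask-then-bump form */
}
#endif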
+ COL_MAJOR_ITCOPY_KERNEL_Kx32(k_step, &A(0, idx_k), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ... + COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_TN_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc); // TODO how to process m + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_TN_32xNx32(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc); + } + + for (BLASLONG idx_m = BF16_BLOCK_THRES_M; idx_m < tag_m_Nx; idx_m += BF16_BLOCK_THRES_M) { + COL_MAJOR_ITCOPY_KERNEL_Kx32(k_step, &A(idx_m, idx_k), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + SBGEMM_BLOCK_KERNEL_TN_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, idx_m), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + SBGEMM_BLOCK_KERNEL_TN_32xNx32(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, idx_m), ldc); + } + } + + if (tag_m_Nx != M) { + m_step = M - tag_m_Nx; + if (m_step > 16) { + COL_MAJOR_ITCOPY_KERNEL_Kx32m(m_step, k_step, &A(tag_m_Nx, idx_k), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + SBGEMM_BLOCK_KERNEL_TN_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + SBGEMM_BLOCK_KERNEL_TN_32xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc); + } + } else { + COL_MAJOR_ITCOPY_KERNEL_Kx16m(m_step, k_step, &A(tag_m_Nx, idx_k), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + SBGEMM_BLOCK_KERNEL_TN_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + SBGEMM_BLOCK_KERNEL_TN_16xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc); + } + } + } + + idx_k += k_step; + k_step = K - idx_k; + k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step; + k_step_round32 = k_step & (~31); + k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32; + } + + n_from = n_to; + n_to += BF16_BLOCK_THRES_N; + n_to = (n_to > N) ? N : n_to; + tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1)); + } + } else { + m_step = M; + if (m_step > 16) { + while (n_from < N) { + for (BLASLONG idx_k = 0; idx_k < K;) { + // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ... + COL_MAJOR_ITCOPY_KERNEL_Kx32m(m_step, k_step, &A(0, idx_k), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ... 
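/*
 * Every block_B offset in these loops follows one rule: a packed B column
 * occupies k_step_round32 bf16 elements, so the panel that starts at column
 * idx_n of the current N block lives at
 *     block_B + (idx_n - n_from) * k_step_round32
 * The same arithmetic as a helper (hypothetical name, sketch only):
 */
#if 0
static bfloat16 *packed_B_panel(bfloat16 *block_B, BLASLONG idx_n,
                                BLASLONG n_from, BLASLONG k_step_round32)
{
    return block_B + (idx_n - n_from) * k_step_round32;
}
#endif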
+ COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_TN_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_TN_32xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc); + } + + idx_k += k_step; + k_step = K - idx_k; + k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step; + k_step_round32 = k_step & (~31); + k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32; + } + n_from = n_to; + n_to += BF16_BLOCK_THRES_N; + n_to = (n_to > N) ? N : n_to; + tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1)); + } + } else { + while (n_from < N) { + for (BLASLONG idx_k = 0; idx_k < K;) { + // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ... + COL_MAJOR_ITCOPY_KERNEL_Kx16m(m_step, k_step, &A(0, idx_k), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ... + COL_MAJOR_ONCOPY_KERNEL_8x32(k_step, &B(idx_n, idx_k), ldb, block_B + (idx_n-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_TN_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + COL_MAJOR_ONCOPY_KERNEL_Nx32(n_step, k_step, &B(tag_n_Nx, idx_k), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_TN_16xNx32(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc); + } + + idx_k += k_step; + k_step = K - idx_k; + k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step; + k_step_round32 = k_step & (~31); + k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32; + } + n_from = n_to; + n_to += BF16_BLOCK_THRES_N; + n_to = (n_to > N) ? 
N : n_to; + tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1)); + } + } + } +} +/* ----------------------------------------- End of TN kernels --------------------------------------- */ + +/* --------------------------------------------- TT kernels ------------------------------------------ */ +// SBGEMM Kernel for 16> (32-m)); + for (int i = 0; i < n; i ++) { + STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i)) + STORE16_MASK_COMPLETE_RESULT(result_512[i+8], (C_addr + ldc*i + 16), tail_mask) + } + } else { + for (int i = 0; i < n; i ++) { + STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i)) + STORE16_COMPLETE_RESULT(result_512[i+8], (C_addr + ldc*i + 16)) + } + } +} + +// SBGEMM Kernel for M<=16, N<8, K can be any number +#ifndef ONE_ALPHA // ALPHA is not ONE +void sbgemm_block_kernel_tt_16xNxK_alpha(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc) +#else // ALPHA is ONE +void sbgemm_block_kernel_tt_16xNxK_one(BLASLONG m, BLASLONG n, BLASLONG k, float alpha, bfloat16 *A, bfloat16 *B, float *C, int ldc) +#endif +{ + bfloat16 * A_addr = A; + bfloat16 * B_addr = B; + float * C_addr = C; + +#ifndef ONE_ALPHA + __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); +#endif + + __m512i arrayA_512_0; + __m512i arrayB_512[8]; + __m512 result_512[8]; + + result_512[0] = _mm512_setzero_ps(); + result_512[1] = _mm512_setzero_ps(); + result_512[2] = _mm512_setzero_ps(); + result_512[3] = _mm512_setzero_ps(); + result_512[4] = _mm512_setzero_ps(); + result_512[5] = _mm512_setzero_ps(); + result_512[6] = _mm512_setzero_ps(); + result_512[7] = _mm512_setzero_ps(); + + for (BLASLONG idx_k = 0; idx_k < k; idx_k += 2) { + // Each two rows are a group for 16-pair bf16 elements + // Load two rows into a 512 register + arrayA_512_0 = _mm512_loadu_si512(A_addr); + A_addr += 32; + + for (int i = 0; i < n; i ++) { + _MM512_BROADCASTD_EPI32(B_addr + i*2, arrayB_512[i]); + } + B_addr += 16; + + for (int i = 0; i < n; i ++) { + result_512[i] = _mm512_dpbf16_ps(result_512[i], (__m512bh) arrayA_512_0, (__m512bh) arrayB_512[i]); + } + } + + if (m != 16) { + unsigned short tail_mask = (((unsigned short)0xffff) >> (16-m)); + for (int i = 0; i < n; i++) { + STORE16_MASK_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i), tail_mask) + } + } else { + for (int i = 0; i < n; i++) { + STORE16_COMPLETE_RESULT(result_512[i], (C_addr + ldc*i)) + } + } +} + +#ifndef ONE_ALPHA // ALPHA is not ONE +void sbgemm_blocking_kernel_tt_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B) +#else // ALPHA is ONE +void sbgemm_blocking_kernel_tt_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B) +#endif +{ + BLASLONG m_step, n_step, k_step, k_step_round32; + BLASLONG tag_m_Nx = M & (~(BF16_BLOCK_THRES_M-1)); + + BLASLONG n_from, n_to; + BLASLONG tag_n_Nx; + + n_from = 0; + n_to = (BF16_BLOCK_THRES_N > N) ? N : BF16_BLOCK_THRES_N; + tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1)); + + k_step = (K > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : K; + k_step_round32 = k_step & (~31); + k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32; + + if (M >= BF16_BLOCK_THRES_M) { + while (n_from < N) { + for (BLASLONG idx_k = 0; idx_k < K;) { + // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ... 
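/*
 * All the m != 16 paths above build the C store mask the same way: shift a
 * 16-bit all-ones pattern right by (16 - m) so exactly the low m lanes are
 * written and rows beyond M are never touched.  Standalone sketch of the
 * idiom (assumes AVX512F; hypothetical helper name):
 */
#if 0
#include <immintrin.h>

static void store_c_tail(float *c, __m512 acc, int m)   /* 1 <= m <= 16 */
{
    __mmask16 mask = (__mmask16)(0xffff >> (16 - m));   /* low m bits set */
    _mm512_mask_storeu_ps(c, mask, acc);                /* m floats only  */
}
#endif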
+ COL_MAJOR_ITCOPY_KERNEL_Kx32(k_step, &A(0, idx_k), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ... + COL_MAJOR_OTCOPY_KERNEL_Kx8(k_step, &B(idx_k, idx_n), ldb, block_B + (idx_n-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_TT_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + COL_MAJOR_OTCOPY_KERNEL_Kx8m(k_step, n_step, &B(idx_k, tag_n_Nx), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_TT_32xNxK(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc); + } + + for (BLASLONG idx_m = BF16_BLOCK_THRES_M; idx_m < tag_m_Nx; idx_m += BF16_BLOCK_THRES_M) { + COL_MAJOR_ITCOPY_KERNEL_Kx32(k_step, &A(idx_m, idx_k), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + SBGEMM_BLOCK_KERNEL_TT_32x8xK(32, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, idx_m), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + SBGEMM_BLOCK_KERNEL_TT_32xNxK(32, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, idx_m), ldc); + } + } + + if (tag_m_Nx != M) { + m_step = M - tag_m_Nx; + if (m_step > 16) { + COL_MAJOR_ITCOPY_KERNEL_Kx32m(m_step, k_step, &A(tag_m_Nx, idx_k), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + SBGEMM_BLOCK_KERNEL_TT_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + SBGEMM_BLOCK_KERNEL_TT_32xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc); + } + } else { + COL_MAJOR_ITCOPY_KERNEL_Kx16m(m_step, k_step, &A(tag_m_Nx, idx_k), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + SBGEMM_BLOCK_KERNEL_TT_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, tag_m_Nx), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + SBGEMM_BLOCK_KERNEL_TT_16xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, tag_m_Nx), ldc); + } + } + } + + idx_k += k_step; + k_step = K - idx_k; + k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step; + k_step_round32 = k_step & (~31); + k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32; + } + + n_from = n_to; + n_to += BF16_BLOCK_THRES_N; + n_to = (n_to > N) ? N : n_to; + tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1)); + } + } else { + m_step = M; + if (m_step > 16) { + while (n_from < N) { + for (BLASLONG idx_k = 0; idx_k < K;) { + // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ... + COL_MAJOR_ITCOPY_KERNEL_Kx32m(m_step, k_step, &A(0, idx_k), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ... 
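/*
 * Taken together, the four blocking kernels share the same dpbf16 block
 * kernels and differ only in which copy routines feed them: A goes through
 * an "in" (no-trans) or "it" (trans) copy, B through an "on" or "ot" copy.
 * Conceptually the transpose dispatch is a 2x2 table; this is a sketch only,
 * the real selection happens through the SBGEMM_BLOCKING_KERNEL_* macros:
 */
#if 0
typedef void (*blocking_fn)(blasint M, blasint N, blasint K, float alpha,
                            bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb,
                            float *C, blasint ldc,
                            bfloat16 *block_A, bfloat16 *block_B);

static blocking_fn pick_blocking(int trans_a, int trans_b)  /* 0 = NoTrans */
{
    static const blocking_fn tab[2][2] = {
        { sbgemm_blocking_kernel_nn_alpha, sbgemm_blocking_kernel_nt_alpha },
        { sbgemm_blocking_kernel_tn_alpha, sbgemm_blocking_kernel_tt_alpha },
    };
    return tab[trans_a][trans_b];
}
#endif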
+ COL_MAJOR_OTCOPY_KERNEL_Kx8(k_step, &B(idx_k, idx_n), ldb, block_B + (idx_n-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_TT_32x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + COL_MAJOR_OTCOPY_KERNEL_Kx8m(k_step, n_step, &B(idx_k, tag_n_Nx), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_TT_32xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc); + } + + idx_k += k_step; + k_step = K - idx_k; + k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step; + k_step_round32 = k_step & (~31); + k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32; + } + n_from = n_to; + n_to += BF16_BLOCK_THRES_N; + n_to = (n_to > N) ? N : n_to; + tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1)); + } + } else { + while (n_from < N) { + for (BLASLONG idx_k = 0; idx_k < K;) { + // Use Kx32 kernel when BF16_BLOCK_THRES_M==32, Kx16 kernel when BF16_BLOCK_THRES_M==16, ... + COL_MAJOR_ITCOPY_KERNEL_Kx16m(m_step, k_step, &A(0, idx_k), lda, block_A); + for (BLASLONG idx_n = n_from; idx_n < tag_n_Nx; idx_n += BF16_BLOCK_STEP_N) { + // Use 8x32 kernel when BF16_BLOCK_THRES_N==8, 4x32 kernel when BF16_BLOCK_THRES_N==4, ... + COL_MAJOR_OTCOPY_KERNEL_Kx8(k_step, &B(idx_k, idx_n), ldb, block_B + (idx_n-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_TT_16x8xK(m_step, k_step, alpha, block_A, block_B + (idx_n-n_from)*k_step_round32, &C(idx_n, 0), ldc); + } + + if (tag_n_Nx != n_to) { + n_step = n_to - tag_n_Nx; + COL_MAJOR_OTCOPY_KERNEL_Kx8m(k_step, n_step, &B(idx_k, tag_n_Nx), ldb, block_B + (tag_n_Nx-n_from)*k_step_round32); + SBGEMM_BLOCK_KERNEL_TT_16xNxK(m_step, n_step, k_step, alpha, block_A, block_B + (tag_n_Nx-n_from)*k_step_round32, &C(tag_n_Nx, 0), ldc); + } + + idx_k += k_step; + k_step = K - idx_k; + k_step = (k_step > BF16_BLOCK_THRES_K) ? BF16_BLOCK_THRES_K : k_step; + k_step_round32 = k_step & (~31); + k_step_round32 = (k_step > k_step_round32) ? (k_step_round32 + 32) : k_step_round32; + } + n_from = n_to; + n_to += BF16_BLOCK_THRES_N; + n_to = (n_to > N) ? 
N : n_to; + tag_n_Nx = n_to & (~(BF16_BLOCK_STEP_N-1)); + } + } + } +} +/* ----------------------------------------- End of TT kernels --------------------------------------- */ + +/* #ifndef ONE_ALPHA // ALPHA is not ONE void sbgemm_internal_kernel_alpha(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K, OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, float *C, OPENBLAS_CONST blasint ldc) @@ -613,13 +1804,34 @@ void sbgemm_internal_kernel_one(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_ OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *B, OPENBLAS_CONST blasint ldb, float *C, OPENBLAS_CONST blasint ldc) #endif { - bfloat16 block_A[BF16_BLOCK_THRES_K * BF16_BLOCK_THRES_M]; - bfloat16 block_B[BF16_BLOCK_THRES_N * BF16_BLOCK_THRES_K]; - - // TODO: assume no trans for both A and B, to complement these scenarios later if (Order == CblasColMajor) { - SBGEMM_BLOCKING_KERNEL_2(M, N, K, alpha, A, lda, B, ldb, C, ldc, block_A, block_B); + if (TransA == CblasNoTrans) { + if (TransB == CblasNoTrans) { + SBGEMM_BLOCKING_KERNEL_NN(M, N, K, alpha, A, lda, B, ldb, C, ldc, block_A, block_B); + } else if (TransB == CblasTrans) { + SBGEMM_BLOCKING_KERNEL_NT(M, N, K, alpha, A, lda, B, ldb, C, ldc, block_A, block_B); + } + } else { + if (TransB == CblasNoTrans) { + SBGEMM_BLOCKING_KERNEL_TN(M, N, K, alpha, A, lda, B, ldb, C, ldc, block_A, block_B); + } else if (TransB == CblasTrans) { + SBGEMM_BLOCKING_KERNEL_TT(M, N, K, alpha, A, lda, B, ldb, C, ldc, block_A, block_B); + } + } } else { - + if (TransA == CblasNoTrans) { + if (TransB == CblasNoTrans) { + SBGEMM_BLOCKING_KERNEL_NN(N, M, K, alpha, B, ldb, A, lda, C, ldc, block_A, block_B); + } else if (TransB == CblasTrans) { + SBGEMM_BLOCKING_KERNEL_TN(N, M, K, alpha, B, ldb, A, lda, C, ldc, block_A, block_B); + } + } else { + if (TransB == CblasNoTrans) { + SBGEMM_BLOCKING_KERNEL_NT(N, M, K, alpha, B, ldb, A, lda, C, ldc, block_A, block_B); + } else if (TransB == CblasTrans) { + SBGEMM_BLOCKING_KERNEL_TT(N, M, K, alpha, B, ldb, A, lda, C, ldc, block_A, block_B); + } + } } -} \ No newline at end of file +} +*/ diff --git a/kernel/x86_64/sbgemm_ncopy_16_cooperlake.c b/kernel/x86_64/sbgemm_ncopy_16_cooperlake.c new file mode 100644 index 000000000..95ed82d7c --- /dev/null +++ b/kernel/x86_64/sbgemm_ncopy_16_cooperlake.c @@ -0,0 +1,353 @@ +/*************************************************************************** +Copyright (c) 2021, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. 
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include +#include +#include "common.h" + +#define _MM512_SHUFFLE_i32(result, in1, in2, imm8) \ + asm("vshufps %3, %2, %1, %0": "=v"(result): "v"(in1), "v"(in2), "N"(imm8)) + +#define REORDER_8x32(t0, t1, t2, t3, t4, t5, t6, t7) { \ + __m512i v; \ + t0 = _mm512_unpacklo_epi32(r0, r1); \ + t1 = _mm512_unpackhi_epi32(r0, r1); \ + t2 = _mm512_unpacklo_epi32(r2, r3); \ + t3 = _mm512_unpackhi_epi32(r2, r3); \ + t4 = _mm512_unpacklo_epi32(r4, r5); \ + t5 = _mm512_unpackhi_epi32(r4, r5); \ + t6 = _mm512_unpacklo_epi32(r6, r7); \ + t7 = _mm512_unpackhi_epi32(r6, r7); \ + _MM512_SHUFFLE_i32(v, t0, t2, 0x4E); \ + r0 = _mm512_mask_blend_epi32(kc, t0, v); \ + r1 = _mm512_mask_blend_epi32(k3, t2, v); \ + _MM512_SHUFFLE_i32(v, t1, t3, 0x4E); \ + r2 = _mm512_mask_blend_epi32(kc, t1, v); \ + r3 = _mm512_mask_blend_epi32(k3, t3, v); \ + _MM512_SHUFFLE_i32(v, t4, t6, 0x4E); \ + r4 = _mm512_mask_blend_epi32(kc, t4, v); \ + r5 = _mm512_mask_blend_epi32(k3, t6, v); \ + _MM512_SHUFFLE_i32(v, t5, t7, 0x4E); \ + r6 = _mm512_mask_blend_epi32(kc, t5, v); \ + r7 = _mm512_mask_blend_epi32(k3, t7, v); \ + t0 = _mm512_permutex2var_epi32(r0, idx_lo, r4); \ + t1 = _mm512_permutex2var_epi32(r1, idx_lo, r5); \ + t2 = _mm512_permutex2var_epi32(r2, idx_lo, r6); \ + t3 = _mm512_permutex2var_epi32(r3, idx_lo, r7); \ + t4 = _mm512_permutex2var_epi32(r0, idx_hi, r4); \ + t5 = _mm512_permutex2var_epi32(r1, idx_hi, r5); \ + t6 = _mm512_permutex2var_epi32(r2, idx_hi, r6); \ + t7 = _mm512_permutex2var_epi32(r3, idx_hi, r7); \ +} + +#define STORE_512_LO(x) \ + v = _mm512_permutex2var_epi64(t0##x, idx_lo2, t1##x); \ + _mm512_storeu_si512(boffset0 + x*32, v); + +#define STORE_512_HI(x) \ + v = _mm512_permutex2var_epi64(t0##x, idx_hi2, t1##x); \ + _mm512_storeu_si512(boffset0 + (x + 8)*32, v); + +#define MASK_STORE_512_LO(x) \ + v = _mm512_permutex2var_epi64(t0##x, idx_lo2, t1##x); \ + _mm512_mask_storeu_epi32(boffset0 + 2*x*remain_n, nmask, v); + +#define MASK_STORE_512_HI(x) \ + v = _mm512_permutex2var_epi64(t0##x, idx_hi2, t1##x); \ + _mm512_mask_storeu_epi32(boffset0 + 2*(x + 8)*remain_n, nmask, v); + +#define STORE_512(x, y) {\ + __m512i v; \ + if (x == 0) { STORE_512_LO(y); } \ + else { STORE_512_HI(y); } \ +} + +#define MASK_STORE_512(x, y) {\ + __m512i v; \ + if (x == 0) { MASK_STORE_512_LO(y); } \ + else { MASK_STORE_512_HI(y); } \ +} + +#define SET_TAIL(y, x) {\ + if (y == 0) tail = _mm512_permutex2var_epi64(t0##x, idx_lo2, t1##x); \ + else tail = _mm512_permutex2var_epi64(t0##x, idx_hi2, t1##x); \ +} + +#define GET_TAIL() \ + switch (n_store + 1) { \ + case 16: SET_TAIL(1, 7); break; \ + case 15: SET_TAIL(1, 6); break; \ + case 14: SET_TAIL(1, 5); break; \ + case 13: SET_TAIL(1, 4); break; \ + 
case 12: SET_TAIL(1, 3); break; \ + case 11: SET_TAIL(1, 2); break; \ + case 10: SET_TAIL(1, 1); break; \ + case 9: SET_TAIL(1, 0); break; \ + case 8: SET_TAIL(0, 7); break; \ + case 7: SET_TAIL(0, 6); break; \ + case 6: SET_TAIL(0, 5); break; \ + case 5: SET_TAIL(0, 4); break; \ + case 4: SET_TAIL(0, 3); break; \ + case 3: SET_TAIL(0, 2); break; \ + case 2: SET_TAIL(0, 1); break; \ + case 1: SET_TAIL(0, 0); break; \ + } + + +int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b){ + BLASLONG i, j; + + IFLOAT *boffset0; + IFLOAT *aoffset; + IFLOAT *aoffset00, *aoffset01, *aoffset02, *aoffset03, *aoffset04, *aoffset05, *aoffset06, *aoffset07; + IFLOAT *aoffset10, *aoffset11, *aoffset12, *aoffset13, *aoffset14, *aoffset15, *aoffset16, *aoffset17; + aoffset = a; + boffset0 = b; + + BLASLONG n16 = n & ~15; + BLASLONG m32 = m & ~31; + + int permute_table[] = { + 0x0, 0x1, 0x2, 0x3, 0x10, 0x11, 0x12, 0x13, 0x8, 0x9, 0xa, 0xb, 0x18, 0x19, 0x1a, 0x1b, + 0x4, 0x5, 0x6, 0x7, 0x14, 0x15, 0x16, 0x17, 0xc, 0xd, 0xe, 0xf, 0x1c, 0x1d, 0x1e, 0x1f, + }; + u_int64_t permute_table2[] = { + 0x00, 0x01, 0x02, 0x03, 8|0x0, 8|0x1, 8|0x2, 8|0x3, + 0x04, 0x05, 0x06, 0x07, 8|0x4, 8|0x5, 8|0x6, 8|0x7, + }; + __m512i idx_lo = _mm512_loadu_si512(permute_table); + __m512i idx_hi = _mm512_loadu_si512(permute_table + 16); + __m512i idx_lo2 = _mm512_loadu_si512(permute_table2); + __m512i idx_hi2 = _mm512_loadu_si512(permute_table2 + 8); + __mmask16 kc = 0xcccc; + __mmask16 k3 = 0x3333; + __m512i r0, r1, r2, r3, r4, r5, r6, r7; + __m512i t00, t01, t02, t03, t04, t05, t06, t07; + __m512i t10, t11, t12, t13, t14, t15, t16, t17; + + for (j = 0; j < n16; j += 16) { + aoffset00 = aoffset; + aoffset01 = aoffset00 + lda; + aoffset02 = aoffset01 + lda; + aoffset03 = aoffset02 + lda; + aoffset04 = aoffset03 + lda; + aoffset05 = aoffset04 + lda; + aoffset06 = aoffset05 + lda; + aoffset07 = aoffset06 + lda; + aoffset10 = aoffset07 + lda; + aoffset11 = aoffset10 + lda; + aoffset12 = aoffset11 + lda; + aoffset13 = aoffset12 + lda; + aoffset14 = aoffset13 + lda; + aoffset15 = aoffset14 + lda; + aoffset16 = aoffset15 + lda; + aoffset17 = aoffset16 + lda; + aoffset += 16 * lda; + for (i = 0; i < m32; i += 32) { + r0 = _mm512_loadu_si512(aoffset00 + i); + r1 = _mm512_loadu_si512(aoffset01 + i); + r2 = _mm512_loadu_si512(aoffset02 + i); + r3 = _mm512_loadu_si512(aoffset03 + i); + r4 = _mm512_loadu_si512(aoffset04 + i); + r5 = _mm512_loadu_si512(aoffset05 + i); + r6 = _mm512_loadu_si512(aoffset06 + i); + r7 = _mm512_loadu_si512(aoffset07 + i); + REORDER_8x32(t00, t01, t02, t03, t04, t05, t06, t07); + r0 = _mm512_loadu_si512(aoffset10 + i); + r1 = _mm512_loadu_si512(aoffset11 + i); + r2 = _mm512_loadu_si512(aoffset12 + i); + r3 = _mm512_loadu_si512(aoffset13 + i); + r4 = _mm512_loadu_si512(aoffset14 + i); + r5 = _mm512_loadu_si512(aoffset15 + i); + r6 = _mm512_loadu_si512(aoffset16 + i); + r7 = _mm512_loadu_si512(aoffset17 + i); + REORDER_8x32(t10, t11, t12, t13, t14, t15, t16, t17); + STORE_512(0, 0); STORE_512(0, 1); STORE_512(0, 2); STORE_512(0, 3); + STORE_512(0, 4); STORE_512(0, 5); STORE_512(0, 6); STORE_512(0, 7); + STORE_512(1, 0); STORE_512(1, 1); STORE_512(1, 2); STORE_512(1, 3); + STORE_512(1, 4); STORE_512(1, 5); STORE_512(1, 6); STORE_512(1, 7); + boffset0 += 16 * 32; + } + if (i < m) { + int remain_m = m - i; + __mmask32 mmask = (1UL << remain_m) - 1; + r0 = _mm512_maskz_loadu_epi16(mmask, aoffset00 + i); + r1 = _mm512_maskz_loadu_epi16(mmask, aoffset01 + i); + r2 = _mm512_maskz_loadu_epi16(mmask, aoffset02 + i); + 
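/*
 * Remainder rows (m - i < 32) come in through zero-masked 16-bit loads, so
 * the transpose network below operates on well-defined zero padding instead
 * of reading past the end of the matrix.  The mask idiom, standalone
 * (assumes AVX512BW; hypothetical helper name):
 */
#if 0
#include <immintrin.h>

static __m512i load_partial_row(const unsigned short *src,
                                int remain_m)             /* 1..31 */
{
    __mmask32 mask = (__mmask32)((1UL << remain_m) - 1);  /* low remain_m bits */
    return _mm512_maskz_loadu_epi16(mask, src);           /* tail lanes read 0 */
}
#endif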
r3 = _mm512_maskz_loadu_epi16(mmask, aoffset03 + i); + r4 = _mm512_maskz_loadu_epi16(mmask, aoffset04 + i); + r5 = _mm512_maskz_loadu_epi16(mmask, aoffset05 + i); + r6 = _mm512_maskz_loadu_epi16(mmask, aoffset06 + i); + r7 = _mm512_maskz_loadu_epi16(mmask, aoffset07 + i); + REORDER_8x32(t00, t01, t02, t03, t04, t05, t06, t07); + r0 = _mm512_maskz_loadu_epi16(mmask, aoffset10 + i); + r1 = _mm512_maskz_loadu_epi16(mmask, aoffset11 + i); + r2 = _mm512_maskz_loadu_epi16(mmask, aoffset12 + i); + r3 = _mm512_maskz_loadu_epi16(mmask, aoffset13 + i); + r4 = _mm512_maskz_loadu_epi16(mmask, aoffset14 + i); + r5 = _mm512_maskz_loadu_epi16(mmask, aoffset15 + i); + r6 = _mm512_maskz_loadu_epi16(mmask, aoffset16 + i); + r7 = _mm512_maskz_loadu_epi16(mmask, aoffset17 + i); + REORDER_8x32(t10, t11, t12, t13, t14, t15, t16, t17); + int n_store = remain_m/2; + switch (n_store) { + case 15: STORE_512(1, 6); + case 14: STORE_512(1, 5); + case 13: STORE_512(1, 4); + case 12: STORE_512(1, 3); + case 11: STORE_512(1, 2); + case 10: STORE_512(1, 1); + case 9: STORE_512(1, 0); + case 8: STORE_512(0, 7); + case 7: STORE_512(0, 6); + case 6: STORE_512(0, 5); + case 5: STORE_512(0, 4); + case 4: STORE_512(0, 3); + case 3: STORE_512(0, 2); + case 2: STORE_512(0, 1); + case 1: STORE_512(0, 0); + } + boffset0 += n_store * 32; + if (m & 0x1) { + __m512i tail; + GET_TAIL(); + _mm256_storeu_si256((void *)boffset0, _mm512_cvtepi32_epi16(tail)); + boffset0 += 16; + } + } + + } + if (j < n) { + int remain_n = n - j; + __mmask16 nmask = (1UL << remain_n) - 1; + int load0, load1; + if (remain_n > 8) { + load0 = 8; + load1 = remain_n - 8; + } else { + load0 = remain_n; + load1 = 0; + } + aoffset00 = aoffset; + aoffset01 = aoffset00 + lda; + aoffset02 = aoffset01 + lda; + aoffset03 = aoffset02 + lda; + aoffset04 = aoffset03 + lda; + aoffset05 = aoffset04 + lda; + aoffset06 = aoffset05 + lda; + aoffset07 = aoffset06 + lda; + aoffset10 = aoffset07 + lda; + aoffset11 = aoffset10 + lda; + aoffset12 = aoffset11 + lda; + aoffset13 = aoffset12 + lda; + aoffset14 = aoffset13 + lda; + aoffset15 = aoffset14 + lda; + aoffset16 = aoffset15 + lda; + aoffset17 = aoffset16 + lda; + aoffset += 16 * lda; + for (i = 0; i < m32; i += 32) { + switch (load0) { + case 8: r7 = _mm512_loadu_si512(aoffset07 + i); + case 7: r6 = _mm512_loadu_si512(aoffset06 + i); + case 6: r5 = _mm512_loadu_si512(aoffset05 + i); + case 5: r4 = _mm512_loadu_si512(aoffset04 + i); + case 4: r3 = _mm512_loadu_si512(aoffset03 + i); + case 3: r2 = _mm512_loadu_si512(aoffset02 + i); + case 2: r1 = _mm512_loadu_si512(aoffset01 + i); + case 1: r0 = _mm512_loadu_si512(aoffset00 + i); + } + REORDER_8x32(t00, t01, t02, t03, t04, t05, t06, t07); + switch (load1) { + case 8: r7 = _mm512_loadu_si512(aoffset17 + i); + case 7: r6 = _mm512_loadu_si512(aoffset16 + i); + case 6: r5 = _mm512_loadu_si512(aoffset15 + i); + case 5: r4 = _mm512_loadu_si512(aoffset14 + i); + case 4: r3 = _mm512_loadu_si512(aoffset13 + i); + case 3: r2 = _mm512_loadu_si512(aoffset12 + i); + case 2: r1 = _mm512_loadu_si512(aoffset11 + i); + case 1: r0 = _mm512_loadu_si512(aoffset10 + i); + } + REORDER_8x32(t10, t11, t12, t13, t14, t15, t16, t17); + MASK_STORE_512(0, 0); MASK_STORE_512(0, 1); MASK_STORE_512(0, 2); MASK_STORE_512(0, 3); + MASK_STORE_512(0, 4); MASK_STORE_512(0, 5); MASK_STORE_512(0, 6); MASK_STORE_512(0, 7); + MASK_STORE_512(1, 0); MASK_STORE_512(1, 1); MASK_STORE_512(1, 2); MASK_STORE_512(1, 3); + MASK_STORE_512(1, 4); MASK_STORE_512(1, 5); MASK_STORE_512(1, 6); MASK_STORE_512(1, 7); + boffset0 += 
remain_n * 32; + } + if (i < m) { + int remain_m = m - i; + __mmask32 mmask = (1UL << remain_m) - 1; + switch (load0) { + case 8: r7 = _mm512_maskz_loadu_epi16(mmask, aoffset07 + i); + case 7: r6 = _mm512_maskz_loadu_epi16(mmask, aoffset06 + i); + case 6: r5 = _mm512_maskz_loadu_epi16(mmask, aoffset05 + i); + case 5: r4 = _mm512_maskz_loadu_epi16(mmask, aoffset04 + i); + case 4: r3 = _mm512_maskz_loadu_epi16(mmask, aoffset03 + i); + case 3: r2 = _mm512_maskz_loadu_epi16(mmask, aoffset02 + i); + case 2: r1 = _mm512_maskz_loadu_epi16(mmask, aoffset01 + i); + case 1: r0 = _mm512_maskz_loadu_epi16(mmask, aoffset00 + i); + } + REORDER_8x32(t00, t01, t02, t03, t04, t05, t06, t07); + switch (load1) { + case 8: r7 = _mm512_maskz_loadu_epi16(mmask, aoffset17 + i); + case 7: r6 = _mm512_maskz_loadu_epi16(mmask, aoffset16 + i); + case 6: r5 = _mm512_maskz_loadu_epi16(mmask, aoffset15 + i); + case 5: r4 = _mm512_maskz_loadu_epi16(mmask, aoffset14 + i); + case 4: r3 = _mm512_maskz_loadu_epi16(mmask, aoffset13 + i); + case 3: r2 = _mm512_maskz_loadu_epi16(mmask, aoffset12 + i); + case 2: r1 = _mm512_maskz_loadu_epi16(mmask, aoffset11 + i); + case 1: r0 = _mm512_maskz_loadu_epi16(mmask, aoffset10 + i); + } + REORDER_8x32(t10, t11, t12, t13, t14, t15, t16, t17); + int n_store = remain_m/2; + switch (n_store) { + case 15: MASK_STORE_512(1, 6); + case 14: MASK_STORE_512(1, 5); + case 13: MASK_STORE_512(1, 4); + case 12: MASK_STORE_512(1, 3); + case 11: MASK_STORE_512(1, 2); + case 10: MASK_STORE_512(1, 1); + case 9: MASK_STORE_512(1, 0); + case 8: MASK_STORE_512(0, 7); + case 7: MASK_STORE_512(0, 6); + case 6: MASK_STORE_512(0, 5); + case 5: MASK_STORE_512(0, 4); + case 4: MASK_STORE_512(0, 3); + case 3: MASK_STORE_512(0, 2); + case 2: MASK_STORE_512(0, 1); + case 1: MASK_STORE_512(0, 0); + } + boffset0 += n_store * remain_n * 2; + if (m & 0x1) { + __m512i tail; + GET_TAIL(); + _mm256_mask_storeu_epi16((void *)boffset0, nmask, _mm512_cvtepi32_epi16(tail)); + } + } + } + return 0; +} diff --git a/kernel/x86_64/sbgemm_ncopy_4_cooperlake.c b/kernel/x86_64/sbgemm_ncopy_4_cooperlake.c new file mode 100644 index 000000000..eefbd7355 --- /dev/null +++ b/kernel/x86_64/sbgemm_ncopy_4_cooperlake.c @@ -0,0 +1,208 @@ +/*************************************************************************** +Copyright (c) 2021, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. 
IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include +#include +#include "common.h" + +#define REORDER_4x32(r0, r1, r2, r3) {\ + __m512i t0, t1, t2, t3; \ + t0 = _mm512_unpacklo_epi32(r0, r1); \ + t1 = _mm512_unpackhi_epi32(r0, r1); \ + t2 = _mm512_unpacklo_epi32(r2, r3); \ + t3 = _mm512_unpackhi_epi32(r2, r3); \ + r0 = _mm512_unpacklo_epi64(t0, t2); \ + r1 = _mm512_unpackhi_epi64(t0, t2); \ + r2 = _mm512_unpacklo_epi64(t1, t3); \ + r3 = _mm512_unpackhi_epi64(t1, t3); \ + t0 = _mm512_permutex2var_epi32(r0, idx_lo_128, r1); \ + t1 = _mm512_permutex2var_epi32(r0, idx_hi_128, r1); \ + t2 = _mm512_permutex2var_epi32(r2, idx_lo_128, r3); \ + t3 = _mm512_permutex2var_epi32(r2, idx_hi_128, r3); \ + r0 = _mm512_permutex2var_epi32(t0, idx_lo_256, t2); \ + r1 = _mm512_permutex2var_epi32(t1, idx_lo_256, t3); \ + r2 = _mm512_permutex2var_epi32(t0, idx_hi_256, t2); \ + r3 = _mm512_permutex2var_epi32(t1, idx_hi_256, t3); \ +} + +#define REORDER_4x8(r0, r1, r2, r3) {\ + __m128i t0, t1, t2, t3; \ + t0 = _mm_unpacklo_epi32(r0, r1); \ + t1 = _mm_unpackhi_epi32(r0, r1); \ + t2 = _mm_unpacklo_epi32(r2, r3); \ + t3 = _mm_unpackhi_epi32(r2, r3); \ + r0 = _mm_unpacklo_epi64(t0, t2); \ + r1 = _mm_unpackhi_epi64(t0, t2); \ + r2 = _mm_unpacklo_epi64(t1, t3); \ + r3 = _mm_unpackhi_epi64(t1, t3); \ +} + +#define GET_TAIL(tail, remain_m) \ + switch((remain_m + 1)/2) { \ + case 1: tail = r0; break; \ + case 2: tail = r1; break; \ + case 3: tail = r2; break; \ + case 4: tail = r3; break; \ + } + +int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b){ + BLASLONG i, j; + IFLOAT *aoffset; + IFLOAT *aoffset0, *aoffset1, *aoffset2, *aoffset3; + + IFLOAT *boffset; + + aoffset = a; + boffset = b; + + BLASLONG m32 = m & ~31; + BLASLONG m8 = m & ~7; + BLASLONG n4 = n & ~3; + + int permute_table[] = { + 0x0, 0x1, 0x2, 0x3, 0x10, 0x11, 0x12, 0x13, 0x8, 0x9, 0xa, 0xb, 0x18, 0x19, 0x1a, 0x1b, + 0x4, 0x5, 0x6, 0x7, 0x14, 0x15, 0x16, 0x17, 0xc, 0xd, 0xe, 0xf, 0x1c, 0x1d, 0x1e, 0x1f, + 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, + }; + __m512i idx_lo_128 = _mm512_loadu_si512(permute_table); + __m512i idx_hi_128 = _mm512_loadu_si512(permute_table + 16); + __m512i idx_lo_256 = _mm512_loadu_si512(permute_table + 32); + __m512i idx_hi_256 = _mm512_loadu_si512(permute_table + 48); + + for (j = 0; j < n4; j += 4) { + aoffset0 = aoffset; + aoffset1 = aoffset0 + lda; + aoffset2 = aoffset1 + lda; + aoffset3 = aoffset2 + lda; + aoffset += 4 * lda; + + for (i = 0; i < m32; i += 32) { + __m512i r0, r1, r2, r3; + r0 = _mm512_loadu_si512(aoffset0 + i); + r1 = _mm512_loadu_si512(aoffset1 + i); + r2 = _mm512_loadu_si512(aoffset2 + i); + r3 = _mm512_loadu_si512(aoffset3 + i); + REORDER_4x32(r0, r1, r2, r3); + _mm512_storeu_si512(boffset + 32*0, r0); + _mm512_storeu_si512(boffset + 32*1, r1); + _mm512_storeu_si512(boffset + 32*2, r2); + 
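/*
 * REORDER_4x32 and REORDER_4x8 above are dword-granularity transposes: each
 * 32-bit element carries two adjacent bf16 values, so transposing dwords
 * interleaves four source rows at exactly the pair granularity the dpbf16
 * kernels consume.  The 4x4 core of REORDER_4x8, written out standalone
 * (SSE2 sketch, illustrative only):
 */
#if 0
#include <emmintrin.h>

static void transpose_4x4_dwords(__m128i r[4])
{
    __m128i t0 = _mm_unpacklo_epi32(r[0], r[1]);  /* a0 b0 a1 b1 */
    __m128i t1 = _mm_unpackhi_epi32(r[0], r[1]);  /* a2 b2 a3 b3 */
    __m128i t2 = _mm_unpacklo_epi32(r[2], r[3]);  /* c0 d0 c1 d1 */
    __m128i t3 = _mm_unpackhi_epi32(r[2], r[3]);  /* c2 d2 c3 d3 */
    r[0] = _mm_unpacklo_epi64(t0, t2);            /* a0 b0 c0 d0 */
    r[1] = _mm_unpackhi_epi64(t0, t2);            /* a1 b1 c1 d1 */
    r[2] = _mm_unpacklo_epi64(t1, t3);            /* a2 b2 c2 d2 */
    r[3] = _mm_unpackhi_epi64(t1, t3);            /* a3 b3 c3 d3 */
}
#endif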
+ _mm512_storeu_si512(boffset + 32*3, r3);
+ boffset += 32 * 4;
+ }
+ for (; i < m8; i += 8) {
+ __m128i r0 = _mm_loadu_si128((void *)(aoffset0 + i));
+ __m128i r1 = _mm_loadu_si128((void *)(aoffset1 + i));
+ __m128i r2 = _mm_loadu_si128((void *)(aoffset2 + i));
+ __m128i r3 = _mm_loadu_si128((void *)(aoffset3 + i));
+ REORDER_4x8(r0, r1, r2, r3);
+ _mm_storeu_si128((void *)(boffset + 8*0), r0);
+ _mm_storeu_si128((void *)(boffset + 8*1), r1);
+ _mm_storeu_si128((void *)(boffset + 8*2), r2);
+ _mm_storeu_si128((void *)(boffset + 8*3), r3);
+ boffset += 8 * 4;
+ }
+ if (i < m) {
+ int remain_m = m - i;
+ __mmask8 r_mask = (1UL << remain_m) - 1;
+ __m128i r0 = _mm_maskz_loadu_epi16(r_mask, aoffset0 + i);
+ __m128i r1 = _mm_maskz_loadu_epi16(r_mask, aoffset1 + i);
+ __m128i r2 = _mm_maskz_loadu_epi16(r_mask, aoffset2 + i);
+ __m128i r3 = _mm_maskz_loadu_epi16(r_mask, aoffset3 + i);
+ REORDER_4x8(r0, r1, r2, r3);
+
+ // the store must skip the trailing odd row
+ int num_store = remain_m/2;
+ switch(num_store) {
+ case 3: _mm_storeu_si128((void *)(boffset + 8*2), r2);
+ case 2: _mm_storeu_si128((void *)(boffset + 8*1), r1);
+ case 1: _mm_storeu_si128((void *)(boffset + 8*0), r0);
+ }
+ boffset += 8 * num_store;
+
+ if (m & 0x1) { // handle the tail row
+ __m128i tail;
+ GET_TAIL(tail, remain_m);
+ /* the tail vector is filled with zeros like:
+ * a, 0, b, 0, c, 0, d, 0
+ * so extract the low word of each pair and store
+ */
+ tail = _mm_cvtepi32_epi16(tail);
+ _mm_store_sd((double *)boffset, (__m128d) tail); // only the lower 4 bfloat16 values are valid
+ boffset += 4;
+ }
+ }
+ }
+ if (j < n) {
+ int remain_n = n - j;
+ __mmask8 nmask = (1UL << remain_n) - 1;
+ aoffset0 = aoffset;
+ aoffset1 = aoffset0 + lda;
+ aoffset2 = aoffset1 + lda;
+ aoffset3 = aoffset2 + lda;
+ __m128i r0, r1, r2, r3;
+ for (i = 0; i < m8; i += 8) {
+ switch (remain_n) {
+ case 3: r2 = _mm_loadu_si128((void *)(aoffset2 + i));
+ case 2: r1 = _mm_loadu_si128((void *)(aoffset1 + i));
+ case 1: r0 = _mm_loadu_si128((void *)(aoffset0 + i));
+ }
+ REORDER_4x8(r0, r1, r2, r3);
+ _mm_mask_storeu_epi32(boffset + remain_n * 0, nmask, r0);
+ _mm_mask_storeu_epi32(boffset + remain_n * 2, nmask, r1);
+ _mm_mask_storeu_epi32(boffset + remain_n * 4, nmask, r2);
+ _mm_mask_storeu_epi32(boffset + remain_n * 6, nmask, r3);
+ boffset += 8 * remain_n;
+ }
+ if (i < m) {
+ int remain_m = m - i;
+ __mmask8 mmask = (1UL << remain_m) - 1;
+ switch (remain_n) {
+ case 3: r2 = _mm_maskz_loadu_epi16(mmask, aoffset2 + i);
+ case 2: r1 = _mm_maskz_loadu_epi16(mmask, aoffset1 + i);
+ case 1: r0 = _mm_maskz_loadu_epi16(mmask, aoffset0 + i);
+ }
+ REORDER_4x8(r0, r1, r2, r3);
+
+ int num_store = remain_m/2;
+ switch (num_store) {
+ case 3: _mm_mask_storeu_epi32(boffset + remain_n * 4, nmask, r2);
+ case 2: _mm_mask_storeu_epi32(boffset + remain_n * 2, nmask, r1);
+ case 1: _mm_mask_storeu_epi32(boffset + remain_n * 0, nmask, r0);
+ }
+ boffset += 2 * num_store * remain_n;
+
+ if (m & 0x1) {
+ __m128i tail;
+ GET_TAIL(tail, remain_m);
+ tail = _mm_cvtepi32_epi16(tail);
+ _mm_mask_storeu_epi16(boffset, nmask, tail);
+ }
+ }
+ }
+ return 0;
+}
diff --git a/kernel/x86_64/sbgemm_small_kernel_nn_cooperlake.c b/kernel/x86_64/sbgemm_small_kernel_nn_cooperlake.c
new file mode 100644
index 000000000..ec40a5054
--- /dev/null
+++ b/kernel/x86_64/sbgemm_small_kernel_nn_cooperlake.c
@@ -0,0 +1,2 @@
+#define TRANS_NN
+#include "sbgemm_small_kernel_template_cooperlake.c"
diff --git a/kernel/x86_64/sbgemm_small_kernel_nt_cooperlake.c b/kernel/x86_64/sbgemm_small_kernel_nt_cooperlake.c
new file mode
100644 index 000000000..1cdfd2936 --- /dev/null +++ b/kernel/x86_64/sbgemm_small_kernel_nt_cooperlake.c @@ -0,0 +1,2 @@ +#define TRANS_NT +#include "sbgemm_small_kernel_template_cooperlake.c" diff --git a/kernel/x86_64/sbgemm_small_kernel_permit_cooperlake.c b/kernel/x86_64/sbgemm_small_kernel_permit_cooperlake.c new file mode 100644 index 000000000..70becd9fa --- /dev/null +++ b/kernel/x86_64/sbgemm_small_kernel_permit_cooperlake.c @@ -0,0 +1,48 @@ +/*************************************************************************** +Copyright (c) 2021, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include "common.h" + +#include "sbgemm_block_microk_cooperlake.c" +// Define micro kernels for ALPHA not ONE scenarios +#undef ONE_ALPHA +#include "sbgemm_microk_cooperlake_template.c" + +// Define micro kernels for ALPHA as ONE scenarios +#define ONE_ALPHA 1 +#include "sbgemm_microk_cooperlake_template.c" + +int CNAME(int transa, int transb, BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, FLOAT beta) +{ + double MNK = (double) M * (double) N * (double) K; + if (MNK > 256.0*256.0*256.0) // disable for big size matrix + return 0; + /* small matrix kernel works well for N = 8, 16, 32 */ + if (N == 8 || N == 16 || N == 32) + return 1; + return 0; +} diff --git a/kernel/x86_64/sbgemm_small_kernel_template_cooperlake.c b/kernel/x86_64/sbgemm_small_kernel_template_cooperlake.c new file mode 100644 index 000000000..1ab7a34ab --- /dev/null +++ b/kernel/x86_64/sbgemm_small_kernel_template_cooperlake.c @@ -0,0 +1,96 @@ +/*************************************************************************** +Copyright (c) 2021, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. 
Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in
+the documentation and/or other materials provided with the
+distribution.
+3. Neither the name of the OpenBLAS project nor the names of
+its contributors may be used to endorse or promote products
+derived from this software without specific prior written permission.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*****************************************************************************/
+
+#include "common.h"
+#include <stdint.h>
+
+extern void sbgemm_scal_operation(BLASLONG M, BLASLONG N, float beta, float *C, BLASLONG ldc);
+extern void sbgemm_zero_operation(BLASLONG M, BLASLONG N, float *C, BLASLONG ldc);
+
+extern void sbgemm_blocking_kernel_nn_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B);
+extern void sbgemm_blocking_kernel_nn_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B);
+extern void sbgemm_blocking_kernel_nt_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B);
+extern void sbgemm_blocking_kernel_nt_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B);
+extern void sbgemm_blocking_kernel_tn_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B);
+extern void sbgemm_blocking_kernel_tn_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B);
+extern void sbgemm_blocking_kernel_tt_alpha(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B);
+extern void sbgemm_blocking_kernel_tt_one(blasint M, blasint N, blasint K, float alpha, bfloat16 *A, blasint lda, bfloat16 *B, blasint ldb, float *C, blasint ldc, bfloat16 * block_A, bfloat16 * block_B);
+
+#if defined(TRANS_NN)
+#define SBGEMM_BLOCKING_KERNEL_ONE sbgemm_blocking_kernel_nn_one
+#define SBGEMM_BLOCKING_KERNEL_ALPHA sbgemm_blocking_kernel_nn_alpha
+#elif defined(TRANS_NT)
+#define SBGEMM_BLOCKING_KERNEL_ONE sbgemm_blocking_kernel_nt_one
+#define SBGEMM_BLOCKING_KERNEL_ALPHA sbgemm_blocking_kernel_nt_alpha
+#elif defined(TRANS_TN)
+#define SBGEMM_BLOCKING_KERNEL_ONE sbgemm_blocking_kernel_tn_one
+#define
SBGEMM_BLOCKING_KERNEL_ALPHA sbgemm_blocking_kernel_tn_alpha +#elif defined(TRANS_TT) +#define SBGEMM_BLOCKING_KERNEL_ONE sbgemm_blocking_kernel_tt_one +#define SBGEMM_BLOCKING_KERNEL_ALPHA sbgemm_blocking_kernel_tt_alpha +#endif + +#define BF16_BLOCK_THRES_K 1024 +// If we want to adjust this to be bigger, need to change COL_MAJOR_INCOPY_KERNEL_Kx32 kernel to be bigger also +#define BF16_BLOCK_THRES_M 32 +#define BF16_BLOCK_THRES_N 1024 + +#define MALLOC_ALIGN64(ptr, size, raw_ptr) \ + raw_ptr = malloc((size) + 63); \ + ptr = (bfloat16 *)(((uintptr_t) raw_ptr + 63) & ~(uintptr_t)63) + + +#if defined(B0) +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, IFLOAT * A, BLASLONG lda, FLOAT alpha, IFLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) +#else +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, IFLOAT * A, BLASLONG lda, FLOAT alpha, IFLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) +#endif +{ + bfloat16 * block_A; + bfloat16 * block_B; + void* raw_ptrA; + void* raw_ptrB; + + MALLOC_ALIGN64(block_A, sizeof(bfloat16) * BF16_BLOCK_THRES_K * BF16_BLOCK_THRES_M, raw_ptrA); + MALLOC_ALIGN64(block_B, sizeof(bfloat16) * BF16_BLOCK_THRES_N * BF16_BLOCK_THRES_K, raw_ptrB); + +#if defined(B0) + sbgemm_zero_operation(M, N, C, ldc); +#else + sbgemm_scal_operation(M, N, beta, C, ldc); +#endif + + if (alpha == ONE) { + SBGEMM_BLOCKING_KERNEL_ONE(M, N, K, alpha, A, lda, B, ldb, C, ldc, block_A, block_B); + } else { + SBGEMM_BLOCKING_KERNEL_ALPHA(M, N, K, alpha, A, lda, B, ldb, C, ldc, block_A, block_B); + } + + free(raw_ptrA); + free(raw_ptrB); + return 0; +} diff --git a/kernel/x86_64/sbgemm_small_kernel_tn_cooperlake.c b/kernel/x86_64/sbgemm_small_kernel_tn_cooperlake.c new file mode 100644 index 000000000..f1a0d0d0c --- /dev/null +++ b/kernel/x86_64/sbgemm_small_kernel_tn_cooperlake.c @@ -0,0 +1,2 @@ +#define TRANS_TN +#include "sbgemm_small_kernel_template_cooperlake.c" diff --git a/kernel/x86_64/sbgemm_small_kernel_tt_cooperlake.c b/kernel/x86_64/sbgemm_small_kernel_tt_cooperlake.c new file mode 100644 index 000000000..8a2a597bc --- /dev/null +++ b/kernel/x86_64/sbgemm_small_kernel_tt_cooperlake.c @@ -0,0 +1,2 @@ +#define TRANS_TT +#include "sbgemm_small_kernel_template_cooperlake.c" diff --git a/kernel/x86_64/sbgemm_tcopy_16_cooperlake.c b/kernel/x86_64/sbgemm_tcopy_16_cooperlake.c new file mode 100644 index 000000000..88725f343 --- /dev/null +++ b/kernel/x86_64/sbgemm_tcopy_16_cooperlake.c @@ -0,0 +1,164 @@ +/*************************************************************************** +Copyright (c) 2021, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. 
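A brief illustration of the MALLOC_ALIGN64 macro defined above: it over-allocates by 63 bytes and rounds the raw pointer up to the next 64-byte boundary, so the aligned pointer is handed to the blocking kernels while the raw pointer is the one later passed to free(). A self-contained sketch of the same arithmetic (the function name is hypothetical):

#include <stdint.h>
#include <stdlib.h>

/* Return a 64-byte-aligned view into an over-allocated buffer; the
 * caller must free(*raw_out), never the aligned pointer. */
static void *alloc_align64(size_t size, void **raw_out)
{
    void *raw = malloc(size + 63);
    if (raw == NULL) return NULL;
    *raw_out = raw;
    return (void *)(((uintptr_t)raw + 63) & ~(uintptr_t)63);
}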
IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*****************************************************************************/
+
+#include <stdint.h>
+#include <immintrin.h>
+#include "common.h"
+
+
+int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b){
+ BLASLONG i, j;
+
+ IFLOAT *boffset0, *boffset1;
+
+ boffset0 = b;
+
+ BLASLONG n32 = n & ~31;
+ BLASLONG m4 = m & ~3;
+ BLASLONG m2 = m & ~1;
+
+ uint32_t permute_table[] = {
+ 0x00, 0x01, 0x02, 0x03, 0x10, 0x11, 0x12, 0x13, 0x04, 0x05, 0x06, 0x07, 0x14, 0x15, 0x16, 0x17,
+ 0x08, 0x09, 0x0a, 0x0b, 0x18, 0x19, 0x1a, 0x1b, 0x0c, 0x0d, 0x0e, 0x0f, 0x1c, 0x1d, 0x1e, 0x1f,
+ };
+
+ __m512i idx_lo = _mm512_loadu_si512(permute_table);
+ __m512i idx_hi = _mm512_loadu_si512(permute_table + 16);
+
+ for (j = 0; j < n32; j += 32) {
+ /* process 2x16 of n at the same time */
+ boffset1 = boffset0 + m * 16;
+ for (i = 0; i < m4; i += 4) {
+ /* the bf16 fma needs a special memory layout:
+ * for a memory layout like below:
+ * a00, a01, a02, a03, a04, a05 ....
+ * a10, a11, a12, a13, a14, a15 ....
+ * it needs to be copied as:
+ * a00, a10, a01, a11, a02, a12, a03, a13, ...
+ */
+ __m512i a0 = _mm512_loadu_si512(&a[(i + 0)*lda + j]);
+ __m512i a1 = _mm512_loadu_si512(&a[(i + 1)*lda + j]);
+ __m512i a2 = _mm512_loadu_si512(&a[(i + 2)*lda + j]);
+ __m512i a3 = _mm512_loadu_si512(&a[(i + 3)*lda + j]);
+
+ __m512i a00 = _mm512_unpacklo_epi16(a0, a1);
+ __m512i a01 = _mm512_unpackhi_epi16(a0, a1);
+ __m512i a10 = _mm512_unpacklo_epi16(a2, a3);
+ __m512i a11 = _mm512_unpackhi_epi16(a2, a3);
+
+ a0 = _mm512_permutex2var_epi32(a00, idx_lo, a01);
+ a1 = _mm512_permutex2var_epi32(a00, idx_hi, a01);
+ a2 = _mm512_permutex2var_epi32(a10, idx_lo, a11);
+ a3 = _mm512_permutex2var_epi32(a10, idx_hi, a11);
+
+ _mm512_storeu_si512(boffset0, a0);
+ _mm512_storeu_si512(boffset1, a1);
+ _mm512_storeu_si512(boffset0 + 32, a2);
+ _mm512_storeu_si512(boffset1 + 32, a3);
+ boffset0 += 64;
+ boffset1 += 64;
+ }
+ for (; i < m2; i += 2) {
+ __m512i a0 = _mm512_loadu_si512(&a[(i + 0)*lda + j]);
+ __m512i a1 = _mm512_loadu_si512(&a[(i + 1)*lda + j]);
+
+ __m512i a00 = _mm512_unpacklo_epi16(a0, a1);
+ __m512i a01 = _mm512_unpackhi_epi16(a0, a1);
+
+ a0 = _mm512_permutex2var_epi32(a00, idx_lo, a01);
+ a1 = _mm512_permutex2var_epi32(a00, idx_hi, a01);
+
+ _mm512_storeu_si512(boffset0, a0);
+ _mm512_storeu_si512(boffset1, a1);
+ boffset0 += 32;
+ boffset1 += 32;
+ }
+ for (; i < m; i++) {
+ /* just copy the single remaining row */
+ __m256i a0 = _mm256_loadu_si256((void *)&a[(i + 0)*lda + j]);
+ __m256i a1 = _mm256_loadu_si256((void *)&a[(i + 0)*lda + j + 16]);
+ _mm256_storeu_si256((void *)boffset0, a0);
+ _mm256_storeu_si256((void *)boffset1, a1);
+ boffset0 += 16;
+ boffset1 += 16;
+ }
+ boffset0 = boffset1;
+ }
+ if (j < n) {
+ uint32_t remains = n - j;
+ __mmask32 r_mask = (1UL << remains) - 1;
+ if (remains > 16) {
+ boffset1 = boffset0 + m * 16;
+ uint32_t tail1 = remains - 16;
+ __mmask16 w_mask1 = (1UL << tail1) - 1;
+ for (i = 0; i < m2; i += 2) {
+ __m512i a0 = _mm512_maskz_loadu_epi16(r_mask, &a[(i + 0)*lda + j]);
+ __m512i a1 =
_mm512_maskz_loadu_epi16(r_mask, &a[(i + 1)*lda + j]); + + __m512i a00 = _mm512_unpacklo_epi16(a0, a1); + __m512i a01 = _mm512_unpackhi_epi16(a0, a1); + + a0 = _mm512_permutex2var_epi32(a00, idx_lo, a01); + a1 = _mm512_permutex2var_epi32(a00, idx_hi, a01); + + _mm512_storeu_si512(boffset0, a0); + _mm512_mask_storeu_epi32(boffset1, w_mask1, a1); + + boffset0 += 32; + boffset1 += 2 * tail1; + } + for (; i < m; i++) { + __m256i a0 = _mm256_loadu_si256((void *)&a[(i + 0)*lda + j]); + __m256i a1 = _mm256_maskz_loadu_epi16(w_mask1, (void *)&a[(i + 0)*lda + j + 16]); + _mm256_storeu_si256((void *)boffset0, a0); + _mm256_mask_storeu_epi16((void *)boffset1, w_mask1, a1); + boffset0 += 16; + boffset1 += tail1; + } + } else { + __mmask16 w_mask = (1UL << remains ) - 1; + for (i = 0; i < m2; i += 2) { + __m512i a0 = _mm512_maskz_loadu_epi16(r_mask, &a[(i + 0)*lda + j]); + __m512i a1 = _mm512_maskz_loadu_epi16(r_mask, &a[(i + 1)*lda + j]); + + __m512i a00 = _mm512_unpacklo_epi16(a0, a1); + __m512i a01 = _mm512_unpackhi_epi16(a0, a1); + + a0 = _mm512_permutex2var_epi32(a00, idx_lo, a01); + + _mm512_mask_storeu_epi32(boffset0, w_mask, a0); + boffset0 += 2 * remains; + } + for (; i < m; i++) { + __m256i a0 = _mm256_maskz_loadu_epi16(w_mask, &a[(i + 0)*lda + j]); + _mm256_mask_storeu_epi16(boffset0, w_mask, a0); + boffset0 += remains; + } + } + } + return 0; +} diff --git a/kernel/x86_64/sbgemm_tcopy_4_cooperlake.c b/kernel/x86_64/sbgemm_tcopy_4_cooperlake.c new file mode 100644 index 000000000..e9edd4571 --- /dev/null +++ b/kernel/x86_64/sbgemm_tcopy_4_cooperlake.c @@ -0,0 +1,216 @@ +/*************************************************************************** +Copyright (c) 2021, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
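Before the 4-wide variant below, it may help to restate in scalar form the layout transform these tcopy kernels implement: pairs of rows are interleaved element-wise, because the BF16 dot-product FMA consumes two bfloat16 values per 32-bit lane. A plain-C sketch for one full pair of rows (the function name is hypothetical; bfloat16 payloads are held as uint16_t here):

#include <stdint.h>

/* Interleave two rows element-wise: a00, a10, a01, a11, a02, a12, ... */
static void interleave_row_pair(const uint16_t *row0, const uint16_t *row1,
                                uint16_t *dst, int n)
{
    for (int k = 0; k < n; k++) {
        dst[2*k + 0] = row0[k];
        dst[2*k + 1] = row1[k];
    }
}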
+*****************************************************************************/
+
+#include <stdint.h>
+#include <immintrin.h>
+#include "common.h"
+
+#define STORE_VEC(Bx, By, vec) \
+ if (By == 0) asm("vmovdqu16 %0, (%1)": : "v"(vec), "r"(boffset##Bx)); \
+ else asm("vmovdqu16 %0, (%1, %2, %c3)": : "v"(vec), "r"(boffset##Bx), "r"(blk_size), "n"(By * 2));
+
+int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b){
+ BLASLONG i, j;
+
+ IFLOAT *boffset0, *boffset1;
+
+ boffset0 = b;
+
+ BLASLONG n24 = n - (n % 24);
+ BLASLONG n8 = n & ~7;
+ BLASLONG m8 = m & ~7;
+ BLASLONG m4 = m & ~3;
+ BLASLONG m2 = m & ~1;
+
+ int permute_table[] = {
+ 0x0, 0x1, 0x2, 0x3, 0x10, 0x11, 0x12, 0x13, 0x8, 0x9, 0xa, 0xb, 0x18, 0x19, 0x1a, 0x1b,
+ 0x4, 0x5, 0x6, 0x7, 0x14, 0x15, 0x16, 0x17, 0xc, 0xd, 0xe, 0xf, 0x1c, 0x1d, 0x1e, 0x1f,
+ 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
+ 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
+ };
+
+ j = 0;
+ if (n > 23) {
+ /* n = 24 is the max width in the current blocking setting */
+ __m512i idx_lo_128 = _mm512_loadu_si512(permute_table);
+ __m512i idx_hi_128 = _mm512_loadu_si512(permute_table + 16);
+ __m512i idx_lo_256 = _mm512_loadu_si512(permute_table + 32);
+ __m512i idx_hi_256 = _mm512_loadu_si512(permute_table + 48);
+ __mmask32 mask24 = (1UL << 24) - 1;
+ BLASLONG blk_size = m * 4;
+ BLASLONG stride = blk_size * 3;
+
+ for (; j < n24; j += 24) {
+ boffset1 = boffset0 + stride;
+ for (i = 0; i < m8; i += 8) {
+ __m512i r0, r1, r2, r3, r4, r5, r6, r7;
+ __m512i t0, t1, t2, t3, t4, t5, t6, t7;
+ r0 = _mm512_maskz_loadu_epi16(mask24, &a[(i + 0)*lda + j]);
+ r1 = _mm512_maskz_loadu_epi16(mask24, &a[(i + 1)*lda + j]);
+ r2 = _mm512_maskz_loadu_epi16(mask24, &a[(i + 2)*lda + j]);
+ r3 = _mm512_maskz_loadu_epi16(mask24, &a[(i + 3)*lda + j]);
+ r4 = _mm512_maskz_loadu_epi16(mask24, &a[(i + 4)*lda + j]);
+ r5 = _mm512_maskz_loadu_epi16(mask24, &a[(i + 5)*lda + j]);
+ r6 = _mm512_maskz_loadu_epi16(mask24, &a[(i + 6)*lda + j]);
+ r7 = _mm512_maskz_loadu_epi16(mask24, &a[(i + 7)*lda + j]);
+
+ t0 = _mm512_unpacklo_epi16(r0, r1);
+ t1 = _mm512_unpackhi_epi16(r0, r1);
+ t2 = _mm512_unpacklo_epi16(r2, r3);
+ t3 = _mm512_unpackhi_epi16(r2, r3);
+ t4 = _mm512_unpacklo_epi16(r4, r5);
+ t5 = _mm512_unpackhi_epi16(r4, r5);
+ t6 = _mm512_unpacklo_epi16(r6, r7);
+ t7 = _mm512_unpackhi_epi16(r6, r7);
+
+ r0 = _mm512_permutex2var_epi32(t0, idx_lo_128, t2);
+ r1 = _mm512_permutex2var_epi32(t1, idx_lo_128, t3);
+ r2 = _mm512_permutex2var_epi32(t4, idx_lo_128, t6);
+ r3 = _mm512_permutex2var_epi32(t5, idx_lo_128, t7);
+ r4 = _mm512_permutex2var_epi32(t0, idx_hi_128, t2);
+ r5 = _mm512_permutex2var_epi32(t1, idx_hi_128, t3);
+ r6 = _mm512_permutex2var_epi32(t4, idx_hi_128, t6);
+ r7 = _mm512_permutex2var_epi32(t5, idx_hi_128, t7);
+
+ t0 = _mm512_permutex2var_epi32(r0, idx_lo_256, r2);
+ t1 = _mm512_permutex2var_epi32(r1, idx_lo_256, r3);
+ t2 = _mm512_permutex2var_epi32(r4, idx_lo_256, r6);
+ t3 = _mm512_permutex2var_epi32(r5, idx_lo_256, r7);
+ t4 = _mm512_permutex2var_epi32(r0, idx_hi_256, r2);
+ t5 = _mm512_permutex2var_epi32(r1, idx_hi_256, r3);
+
+ STORE_VEC(0, 0, t0); STORE_VEC(0, 1, t1); STORE_VEC(0, 2, t2);
+ STORE_VEC(1, 0, t3); STORE_VEC(1, 1, t4); STORE_VEC(1, 2, t5);
+ boffset0 += 32;
+ boffset1 += 32;
+ }
+ for (; i < m2; i += 2) {
+ __m512i r0, r1, t0, t1;
+ r0 = _mm512_maskz_loadu_epi16(mask24, &a[(i + 0)*lda + j]);
+ r1 = _mm512_maskz_loadu_epi16(mask24, &a[(i + 1)*lda + j]);
+ t0 =
_mm512_unpacklo_epi16(r0, r1); + t1 = _mm512_unpackhi_epi16(r0, r1); + STORE_VEC(0, 0, _mm512_extracti32x4_epi32(t0, 0)); + STORE_VEC(0, 1, _mm512_extracti32x4_epi32(t1, 0)); + STORE_VEC(0, 2, _mm512_extracti32x4_epi32(t0, 1)); + STORE_VEC(1, 0, _mm512_extracti32x4_epi32(t1, 1)); + STORE_VEC(1, 1, _mm512_extracti32x4_epi32(t0, 2)); + STORE_VEC(1, 2, _mm512_extracti32x4_epi32(t1, 2)); + boffset0 += 8; + boffset1 += 8; + } + for (; i < m; i++) { + *(uint64_t *)(boffset0 + blk_size * 0) = *(uint64_t *)&a[i * lda + j + 0]; + *(uint64_t *)(boffset0 + blk_size * 1) = *(uint64_t *)&a[i * lda + j + 4]; + *(uint64_t *)(boffset0 + blk_size * 2) = *(uint64_t *)&a[i * lda + j + 8]; + *(uint64_t *)(boffset1 + blk_size * 0) = *(uint64_t *)&a[i * lda + j + 12]; + *(uint64_t *)(boffset1 + blk_size * 1) = *(uint64_t *)&a[i * lda + j + 16]; + *(uint64_t *)(boffset1 + blk_size * 2) = *(uint64_t *)&a[i * lda + j + 20]; + boffset0 += 4; + boffset1 += 4; + } + boffset0 += stride * 2; + } + } + + for (; j < n8; j += 8) { + boffset1 = boffset0 + m * 4; + for (i = 0; i < m4; i += 4) { + __m128i a0 = _mm_loadu_si128((void *)&a[(i + 0)*lda + j]); + __m128i a1 = _mm_loadu_si128((void *)&a[(i + 1)*lda + j]); + __m128i a2 = _mm_loadu_si128((void *)&a[(i + 2)*lda + j]); + __m128i a3 = _mm_loadu_si128((void *)&a[(i + 3)*lda + j]); + __m128i a00 = _mm_unpacklo_epi16(a0, a1); + __m128i a01 = _mm_unpackhi_epi16(a0, a1); + __m128i a10 = _mm_unpacklo_epi16(a2, a3); + __m128i a11 = _mm_unpackhi_epi16(a2, a3); + _mm_storeu_si128((void *)(boffset0 + 0), a00); + _mm_storeu_si128((void *)(boffset0 + 8), a10); + _mm_storeu_si128((void *)(boffset1 + 0), a01); + _mm_storeu_si128((void *)(boffset1 + 8), a11); + boffset0 += 16; + boffset1 += 16; + } + for (; i < m2; i+= 2) { + __m128i a0 = _mm_loadu_si128((void *)&a[(i + 0)*lda + j]); + __m128i a1 = _mm_loadu_si128((void *)&a[(i + 1)*lda + j]); + __m128i a00 = _mm_unpacklo_epi16(a0, a1); + __m128i a01 = _mm_unpackhi_epi16(a0, a1); + _mm_storeu_si128((void *)(boffset0 + 0), a00); + _mm_storeu_si128((void *)(boffset1 + 0), a01); + boffset0 += 8; + boffset1 += 8; + } + for (; i < m; i++) { + __m128d a0 = _mm_loadu_pd((void *)&a[(i + 0)*lda + j]); + _mm_store_sd((void *)boffset0, a0); + _mm_store_sd((void *)boffset1, _mm_permute_pd(a0, 0x1)); + boffset0 += 4; + boffset1 += 4; + } + boffset0 = boffset1; + } + if (j < n) { + uint32_t remains = n - j; + __mmask8 r_mask = (1UL << remains) - 1; + if (remains > 4) { + boffset1 = boffset0 + m * 4; + uint32_t tail1 = remains - 4; + __mmask8 w_mask1 = (1UL << tail1) - 1; + for (i = 0; i < m2; i += 2) { + __m128i a0 = _mm_maskz_loadu_epi16(r_mask, &a[(i + 0)*lda + j]); + __m128i a1 = _mm_maskz_loadu_epi16(r_mask, &a[(i + 1)*lda + j]); + __m128i a00 = _mm_unpacklo_epi16(a0, a1); + __m128i a01 = _mm_unpackhi_epi16(a0, a1); + _mm_storeu_si128((void *)boffset0, a00); + _mm_mask_storeu_epi32((void *)boffset1, w_mask1, a01); + boffset0 += 8; + boffset1 += 2 * tail1; + } + for (; i < m; i++) { + __m128i a0 = _mm_maskz_loadu_epi16(r_mask, &a[(i + 0)*lda + j]); + _mm_store_sd((void *)boffset0, (__m128d) a0); + _mm_mask_storeu_epi16((void *)boffset1, w_mask1, (__m128i) _mm_permute_pd((__m128d) a0, 0x1)); + boffset0 += 4; + boffset1 += tail1; + } + } else { + for (i = 0; i < m2; i += 2) { + __m128i a0 = _mm_maskz_loadu_epi16(r_mask, &a[(i + 0)*lda + j]); + __m128i a1 = _mm_maskz_loadu_epi16(r_mask, &a[(i + 1)*lda + j]); + __m128i a00 = _mm_unpacklo_epi16(a0, a1); + _mm_mask_storeu_epi32((void *)boffset0, r_mask, a00); + boffset0 += 2 * remains; + } + for (; i 
< m; i++) { + __m128i a0 = _mm_maskz_loadu_epi16(r_mask, &a[(i + 0)*lda + j]); + _mm_mask_storeu_epi16((void *)boffset0, r_mask, a0); + } + } + } + return 0; +} diff --git a/kernel/x86_64/sbgemv_n_microk_cooperlake_template.c b/kernel/x86_64/sbgemv_n_microk_cooperlake_template.c index 46e6d0ff9..4711e9720 100644 --- a/kernel/x86_64/sbgemv_n_microk_cooperlake_template.c +++ b/kernel/x86_64/sbgemv_n_microk_cooperlake_template.c @@ -30,6 +30,13 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // Include common macros for BF16 based operations with IA intrinsics #include "bf16_common_macros.h" +#undef STORE16_COMPLETE_RESULT +#undef STORE16_MASK_COMPLETE_RESULT +#undef STORE8_COMPLETE_RESULT +#undef STORE8_MASK_COMPLETE_RESULT +#undef STORE4_COMPLETE_RESULT +#undef STORE4_MASK_COMPLETE_RESULT + #ifndef ZERO_BETA // Beta is non-zero #ifndef ONE_BETA // BETA is not ONE @@ -103,7 +110,9 @@ static int sbgemv_kernel_32xN_lda_direct(BLASLONG m, BLASLONG n, float alpha, bf __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif __m512i matrixArray_seed_0, matrixArray_seed_1, matrixArray_seed_2, matrixArray_seed_3; @@ -202,7 +211,7 @@ static int sbgemv_kernel_32xN_lda_direct(BLASLONG m, BLASLONG n, float alpha, bf unsigned int tail_mask_value = (((unsigned int)0xffffffff) >> (32-(m&31))); __mmask32 tail_mask = *((__mmask32*) &tail_mask_value); - unsigned short store_tail_mask_value = (((unsigned int)0xffff) >> (16-(m&15))); + unsigned int store_tail_mask_value = (((unsigned int)0xffff) >> (16-(m&15))); __mmask32 store_tail_mask = *((__mmask32*) &store_tail_mask_value); accum512_0 = _mm512_setzero_ps(); diff --git a/kernel/x86_64/sbgemv_t_microk_cooperlake_template.c b/kernel/x86_64/sbgemv_t_microk_cooperlake_template.c index 51e681add..8a3a022fb 100644 --- a/kernel/x86_64/sbgemv_t_microk_cooperlake_template.c +++ b/kernel/x86_64/sbgemv_t_microk_cooperlake_template.c @@ -29,6 +29,13 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
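Two notes on the sbgemv hunks here, offered as a reading of the changes rather than anything stated in the patch itself: the `(__m256i *)` casts added below silence incompatible-pointer warnings, since `_mm256_loadu_si256` is declared to take a `__m256i const *`; and the earlier widening of `store_tail_mask_value` from `unsigned short` to `unsigned int` matters because the value is immediately reinterpreted through a `__mmask32 *`, i.e. four bytes are read back from the object:

unsigned short v16 = 0xffu;
__mmask32 bad = *((__mmask32 *)&v16);  /* reads 4 bytes from a 2-byte object */

unsigned int v32 = 0xffu;
__mmask32 ok  = *((__mmask32 *)&v32);  /* object is 4 bytes, read stays in bounds */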
// Include common macros for BF16 based operations with IA intrinsics #include "bf16_common_macros.h" +#undef STORE16_COMPLETE_RESULT +#undef STORE16_MASK_COMPLETE_RESULT +#undef STORE8_COMPLETE_RESULT +#undef STORE8_MASK_COMPLETE_RESULT +#undef STORE4_COMPLETE_RESULT +#undef STORE4_MASK_COMPLETE_RESULT + #ifndef ZERO_BETA // Beta is non-zero #ifndef ONE_BETA // BETA is not ONE @@ -231,7 +238,9 @@ static int sbgemv_kernel_32x2(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif unsigned char load_mask_value = (((unsigned char)0xff) >> 6); @@ -280,7 +289,7 @@ static int sbgemv_kernel_32x2(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, } else if (tail_num == 8) { __m256 result256 = _mm256_setzero_ps(); - __m256i matrixArray256 = _mm256_loadu_si256(&a[(tag_m_32x)*2]); // Load 8 rows with n=2 + __m256i matrixArray256 = _mm256_loadu_si256((__m256i *)&a[(tag_m_32x)*2]); // Load 8 rows with n=2 __m256i xArray256 = _mm512_castsi512_si256(xArray); result256 = _mm256_dpbf16_ps(result256, (__m256bh) matrixArray256, (__m256bh) xArray256); @@ -323,7 +332,9 @@ static int sbgemv_kernel_32x3(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif unsigned char x_load_mask_value = (((unsigned char)0xff) >> 5); @@ -395,9 +406,9 @@ static int sbgemv_kernel_32x3(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, result256_0 = _mm256_setzero_ps(); result256_1 = _mm256_setzero_ps(); - matrixArray256_0 = _mm256_loadu_si256(&a[(tag_m_32x)*3]); // Load 5 rows with n=3 plus 1 element - matrixArray256_1 = _mm256_loadu_si256(&a[((tag_m_32x+5)*3 + 1)]); // Load 5 rows with n=3 plus 1 element - matrixArray256_2 = _mm256_loadu_si256(&a[((tag_m_32x+10)*3 + 2)]); // Load 5 rows with n=3 plus 1 element + matrixArray256_0 = _mm256_loadu_si256((__m256i *)&a[(tag_m_32x)*3]); // Load 5 rows with n=3 plus 1 element + matrixArray256_1 = _mm256_loadu_si256((__m256i *)&a[((tag_m_32x+5)*3 + 1)]); // Load 5 rows with n=3 plus 1 element + matrixArray256_2 = _mm256_loadu_si256((__m256i *)&a[((tag_m_32x+10)*3 + 2)]); // Load 5 rows with n=3 plus 1 element matrixArray256_3 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx01_1st, matrixArray256_1); // Select the first 2 elements for each row matrixArray256_4 = _mm256_permutex2var_epi16(matrixArray256_1, load256_idx01_2nd, matrixArray256_2); // Select the first 2 elements for each row @@ -423,8 +434,8 @@ static int sbgemv_kernel_32x3(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, if (tail_num > 10) { unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-((tail_num-10-1)*3+1))); __mmask16 tail_mask = *((__mmask16*) &tail_mask_value); - matrixArray256_0 = _mm256_loadu_si256(&a[(tag_m_32x)*3]); // Load 5 rows with n=3 plus 1 element - matrixArray256_1 = _mm256_loadu_si256(&a[((tag_m_32x+5)*3 + 1)]); // Load 5 rows with n=3 plus 1 element + matrixArray256_0 = _mm256_loadu_si256((__m256i *)&a[(tag_m_32x)*3]); // Load 5 rows with n=3 plus 1 element + matrixArray256_1 = _mm256_loadu_si256((__m256i *)&a[((tag_m_32x+5)*3 + 1)]); // Load 5 rows with n=3 plus 1 element matrixArray256_2 = _mm256_maskz_loadu_epi16(tail_mask, &a[((tag_m_32x+10)*3 + 2)]); // Load m-tag_m_32x-10 rows matrixArray256_3 = _mm256_permutex2var_epi16(matrixArray256_0, load256_idx01_1st, matrixArray256_1); // Select the 
first 2 elements for each row @@ -439,7 +450,7 @@ static int sbgemv_kernel_32x3(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, } else if (tail_num > 5) { unsigned short tail_mask_value = (((unsigned short)0xffff) >> (16-((tail_num-5-1)*3+2))); __mmask16 tail_mask = *((__mmask16*) &tail_mask_value); - matrixArray256_0 = _mm256_loadu_si256(&a[(tag_m_32x)*3]); // Load 5 rows with n=3 plus 1 element + matrixArray256_0 = _mm256_loadu_si256((__m256i *)&a[(tag_m_32x)*3]); // Load 5 rows with n=3 plus 1 element matrixArray256_1 = _mm256_maskz_loadu_epi16(tail_mask, &a[((tag_m_32x+5)*3+1)]); // Load m-tag_m_32x-5 rows matrixArray256_2 = _mm256_setzero_si256(); @@ -499,7 +510,9 @@ static int sbgemv_kernel_16x4(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif __m512i M512_EPI32_1 = _mm512_set1_epi32(1); @@ -591,7 +604,9 @@ static int sbgemv_kernel_30x5(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif __m512 result_0, result_1; @@ -782,7 +797,9 @@ static int sbgemv_kernel_16x6(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif __m512i M512_EPI32_1 = _mm512_set1_epi32(1); @@ -866,9 +883,9 @@ static int sbgemv_kernel_16x6(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, result256_0 = _mm256_setzero_ps(); - matrixArray_0 = _mm256_loadu_si256(&a[(tag_m_16x)*6]); // Load 2 rows with n=6 plus 4 element - matrixArray_1 = _mm256_loadu_si256(&a[((tag_m_16x+2)*6 + 4)]); // Load 2 rows with n=6 plus 4 element - matrixArray_2 = _mm256_loadu_si256(&a[((tag_m_16x+5)*6 + 2)]); // Load 2 rows with n=6 plus 4 element + matrixArray_0 = _mm256_loadu_si256((__m256i *)&a[(tag_m_16x)*6]); // Load 2 rows with n=6 plus 4 element + matrixArray_1 = _mm256_loadu_si256((__m256i *)&a[((tag_m_16x+2)*6 + 4)]); // Load 2 rows with n=6 plus 4 element + matrixArray_2 = _mm256_loadu_si256((__m256i *)&a[((tag_m_16x+5)*6 + 2)]); // Load 2 rows with n=6 plus 4 element // Process the 0|1 elements // Select the 0|1 elements for each row @@ -957,7 +974,9 @@ static int sbgemv_kernel_16x7(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif __m512i M512_EPI32_2 = _mm512_set1_epi32(2); @@ -1110,7 +1129,7 @@ static int sbgemv_kernel_16x8(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, { BLASLONG tag_m_16x = m & (~15); - __m128i x128 = _mm_loadu_si128(x); // |x0|x1|x2|x3|x4|x5|x6|x7| + __m128i x128 = _mm_loadu_si128((__m128i *)x); // |x0|x1|x2|x3|x4|x5|x6|x7| if (tag_m_16x > 0) { __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3; @@ -1122,7 +1141,9 @@ static int sbgemv_kernel_16x8(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif __m512i M512_EPI32_2 = _mm512_set1_epi32(2); @@ -1214,7 +1235,7 @@ static int sbgemv_kernel_16x8(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, __m128 result128, tmp128; for (BLASLONG i = tag_m_16x; i < m; i++) { result128 = _mm_setzero_ps(); - matrixArray128 = 
_mm_loadu_si128(&a[(i)*8]); // Load 1 rows with n=8 + matrixArray128 = _mm_loadu_si128((__m128i *)&a[(i)*8]); // Load 1 rows with n=8 result128 = _mm_dpbf16_ps(result128, (__m128bh) matrixArray128, (__m128bh) x128); tmp128 = _mm_shuffle_ps(result128, result128, 14); result128 = _mm_add_ps(result128, tmp128); @@ -1258,7 +1279,7 @@ static int sbgemv_kernel_14x9(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, unsigned char x_load_mask_value = (((unsigned char)0xff) >> 7); __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value); - __m128i x128_0 = _mm_loadu_si128(x); // |x0|x1|x2|x3|x4|x5|x6|x7| + __m128i x128_0 = _mm_loadu_si128((__m128i *)x); // |x0|x1|x2|x3|x4|x5|x6|x7| __m128i x128_1 = _mm_maskz_loadu_epi16(x_load_mask, (x+8)); // |x8|0 |0 | 0| 0| 0| 0| 0| if (tag_m_14x > 0) { @@ -1271,7 +1292,9 @@ static int sbgemv_kernel_14x9(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x, __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif __m256i M256_EPI16_2 = _mm256_set1_epi16(2); @@ -1390,7 +1413,7 @@ static int sbgemv_kernel_12x10(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x unsigned char x_load_mask_value = (((unsigned char)0xf) >> 3); __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value); - __m128i x128_0 = _mm_loadu_si128(x); // |x0|x1|x2|x3|x4|x5|x6|x7| + __m128i x128_0 = _mm_loadu_si128((__m128i *)x); // |x0|x1|x2|x3|x4|x5|x6|x7| __m128i x128_1 = _mm_maskz_loadu_epi32(x_load_mask, (x+8)); // |x8|x9|0 | 0| 0| 0| 0| 0| if (tag_m_12x > 0) { @@ -1403,7 +1426,9 @@ static int sbgemv_kernel_12x10(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif __m256i M256_EPI32_1 = _mm256_set1_epi32(1); @@ -1522,7 +1547,7 @@ static int sbgemv_kernel_15x11(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x unsigned char x_load_mask_value = (((unsigned char)0xff) >> 5); __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value); - __m128i x128_0 = _mm_loadu_si128(x); // |x0|x1| x2|x3|x4|x5|x6|x7| + __m128i x128_0 = _mm_loadu_si128((__m128i *)x); // |x0|x1| x2|x3|x4|x5|x6|x7| __m128i x128_1 = _mm_maskz_loadu_epi16(x_load_mask, (x+8)); // |x8|x9|x10| 0| 0| 0| 0| 0| if (tag_m_15x > 0) { @@ -1535,7 +1560,9 @@ static int sbgemv_kernel_15x11(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif __m512i idx_stage1_base_0, idx_stage1_base_1, idx_stage1_base_2, idx_stage1_base_3, idx_stage1_base_4, idx_stage1_base_5; @@ -1690,7 +1717,7 @@ static int sbgemv_kernel_15x12(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x unsigned char x_load_mask_value = (((unsigned char)0xff) >> 4); __mmask8 x_load_mask = *((__mmask8*) &x_load_mask_value); - __m128i x128_0 = _mm_loadu_si128(x); // |x0|x1| x2| x3|x4|x5|x6|x7| + __m128i x128_0 = _mm_loadu_si128((__m128i *)x); // |x0|x1| x2| x3|x4|x5|x6|x7| __m128i x128_1 = _mm_maskz_loadu_epi16(x_load_mask, (x+8)); // |x8|x9|x10|x11| 0| 0| 0| 0| if (tag_m_15x > 0) { @@ -1703,7 +1730,9 @@ static int sbgemv_kernel_15x12(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif __m512i idx_stage1_base_0, idx_stage1_base_1, idx_stage1_base_2, idx_stage1_base_3, idx_stage1_base_4, 
idx_stage1_base_5; @@ -1873,16 +1902,15 @@ static int sbgemv_kernel_16x13(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif __m512i M512_EPI32_4 = _mm512_set1_epi32(4); __m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0); __m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4); - unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 6); - __mmask32 load_mask = *((__mmask32*) &load_mask_value); - // Prepare X with 2-step interleave way xArray_0 = _mm512_inserti32x8(_mm512_castsi256_si512(x256), x256, 0x1); BF16_INTERLEAVE_1x32(xArray) @@ -2045,7 +2073,9 @@ static int sbgemv_kernel_16x14(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif __m512i M512_EPI32_4 = _mm512_set1_epi32(4); @@ -2207,16 +2237,15 @@ static int sbgemv_kernel_16x15(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif __m512i M512_EPI32_4 = _mm512_set1_epi32(4); __m512i idx_base_0 = _mm512_set_epi32(27, 26, 25, 24, 11, 10, 9, 8, 19, 18, 17, 16, 3, 2, 1, 0); __m512i idx_base_1 = _mm512_add_epi32(idx_base_0, M512_EPI32_4); - unsigned int load_mask_value = (((unsigned int)0xffffffff) >> 2); - __mmask32 load_mask = *((__mmask32*) &load_mask_value); - // Prepare X with 2-step interleave way xArray_0 = _mm512_inserti32x8(_mm512_castsi256_si512(x256), x256, 0x1); BF16_INTERLEAVE_1x32(xArray) @@ -2364,7 +2393,7 @@ static int sbgemv_kernel_16x16(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x { BLASLONG tag_m_16x = m & (~15); - __m256i x256 = _mm256_loadu_si256(x); // |x0|x1|x2|x3|x4|x5|x6|x7|x8|x9|x10|x11|x12|x13|x14|x15| + __m256i x256 = _mm256_loadu_si256((__m256i *)x); // |x0|x1|x2|x3|x4|x5|x6|x7|x8|x9|x10|x11|x12|x13|x14|x15| if (tag_m_16x > 0) { __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7, \ @@ -2377,7 +2406,9 @@ static int sbgemv_kernel_16x16(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif __m512i M512_EPI32_4 = _mm512_set1_epi32(4); @@ -2484,7 +2515,7 @@ static int sbgemv_kernel_16x16(BLASLONG m, float alpha, bfloat16 *a, bfloat16 *x __m128 accum128, tmp128; for (BLASLONG i = tag_m_16x; i < m; i++) { accum256 = _mm256_setzero_ps(); - matrixArray256 = _mm256_loadu_si256(&a[(i)*16]); // Load 1 rows with n=16 + matrixArray256 = _mm256_loadu_si256((__m256i *)&a[(i)*16]); // Load 1 rows with n=16 accum256 = _mm256_dpbf16_ps(accum256, (__m256bh) matrixArray256, (__m256bh) x256); accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256), _mm256_extractf32x4_ps(accum256, 1)); tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e); @@ -2535,7 +2566,9 @@ static int sbgemv_kernel_8x16p_lda(BLASLONG m, BLASLONG n, float alpha, bfloat16 __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7, \ @@ -2647,8 +2680,6 @@ static int 
sbgemv_kernel_1x128_lda_direct(BLASLONG m, BLASLONG n, float alpha, b BLASLONG tag_n_32x = n & (~31); BLASLONG tag_n_128x = n & (~127); - __m512 accum512_0, accum512_1, accum512_2, accum512_3, accum512_4, accum512_5, accum512_6, accum512_7, \ - accum512_8, accum512_9, accum512_10, accum512_11, accum512_12, accum512_13, accum512_14, accum512_15; __m512 accum512_bridge[8]; __m512 accum512_t_0, accum512_t_1, accum512_t_2, accum512_t_3; __m256 accum256_0; @@ -2658,7 +2689,9 @@ static int sbgemv_kernel_1x128_lda_direct(BLASLONG m, BLASLONG n, float alpha, b __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3; @@ -2825,7 +2858,9 @@ static int sbgemv_kernel_8x32_lda_direct(BLASLONG m, BLASLONG n, float alpha, bf __m512 ALPHAVECTOR = _mm512_set1_ps(alpha); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_set1_ps(beta); +#endif #endif __m512i matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7; @@ -2961,7 +2996,9 @@ static int sbgemv_kernel_8x16m_lda(BLASLONG m, BLASLONG n, float alpha, bfloat16 __m512 ALPHAVECTOR = _mm512_castps256_ps512(_mm256_set1_ps(alpha)); #endif #ifndef ZERO_BETA +#ifndef ONE_BETA __m512 BETAVECTOR = _mm512_castps256_ps512(_mm256_set1_ps(beta)); +#endif #endif __m256 accum256_0, accum256_1, accum256_2, accum256_3, accum256_4, accum256_5, accum256_6, accum256_7, \ @@ -3012,7 +3049,7 @@ static int sbgemv_kernel_8x16m_lda(BLASLONG m, BLASLONG n, float alpha, bfloat16 __m128 accum128, tmp128; for (BLASLONG i = tag_m_8x; i < m; i++) { accum256_0 = _mm256_setzero_ps(); - matrixArray_0 = _mm256_loadu_si256(&a[(i)*lda]); // Load 1 rows with n=16 + matrixArray_0 = _mm256_loadu_si256((__m256i *)&a[(i)*lda]); // Load 1 rows with n=16 accum256_0 = _mm256_dpbf16_ps(accum256_0, (__m256bh) matrixArray_0, (__m256bh) xArray256); accum128 = _mm_add_ps(_mm256_castps256_ps128(accum256_0), _mm256_extractf32x4_ps(accum256_0, 1)); tmp128 = _mm_shuffle_ps(accum128, accum128, 0x0e); diff --git a/kernel/x86_64/sgemm_beta_skylakex.c b/kernel/x86_64/sgemm_beta_skylakex.c index 1c29c1168..6217acf48 100644 --- a/kernel/x86_64/sgemm_beta_skylakex.c +++ b/kernel/x86_64/sgemm_beta_skylakex.c @@ -41,7 +41,7 @@ #include int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT beta, - FLOAT *dummy2, BLASLONG dummy3, FLOAT *dummy4, BLASLONG dummy5, + IFLOAT *dummy2, BLASLONG dummy3, IFLOAT *dummy4, BLASLONG dummy5, FLOAT *c, BLASLONG ldc){ BLASLONG i, j; diff --git a/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c b/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c new file mode 100644 index 000000000..9bc7a7c58 --- /dev/null +++ b/kernel/x86_64/sgemm_small_kernel_nn_skylakex.c @@ -0,0 +1,612 @@ +/*************************************************************************** +Copyright (c) 2021, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. 
Neither the name of the OpenBLAS project nor the names of
+its contributors may be used to endorse or promote products
+derived from this software without specific prior written permission.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*****************************************************************************/
+
+#include <immintrin.h>
+#include "common.h"
+#include <stdio.h>
+#include <memory.h>
+
+#define DECLARE_RESULT_512(M, N) __m512 result##M##N = _mm512_setzero_ps()
+#define LOAD_A_512(M, N) __m512 Aval##M = _mm512_loadu_ps(&A[lda * k + i + (M*16)])
+#define MASK_LOAD_A_512(M, N) __m512 Aval##M = _mm512_maskz_loadu_ps(mask, &A[lda * k + i + (M*16)])
+#define BROADCAST_LOAD_B_512(M, N) __m512 Bval##N = _mm512_broadcastss_ps(_mm_load_ss(&B[k + ldb * (j+N)]))
+#define MATMUL_512(M, N) result##M##N = _mm512_fmadd_ps(Aval##M, Bval##N, result##M##N)
+#if defined(B0)
+#define STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
+ _mm512_storeu_ps(&C[(j+N)*ldc + i + (M*16)], result##M##N)
+#define MASK_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
+ _mm512_mask_storeu_ps(&C[(j+N)*ldc + i + (M*16)], mask, result##M##N)
+#else
+#define STORE_512(M, N) \
+ result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
+ asm("vfmadd231ps (%1), %2, %0": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*16)]), "v"(beta_512)); \
+ _mm512_storeu_ps(&C[(j+N)*ldc + i + (M*16)], result##M##N)
+#define MASK_STORE_512(M, N) \
+ result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
+ asm("vfmadd231ps (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*16)]), "v"(beta_512), "k"(mask)); \
+ _mm512_mask_storeu_ps(&C[(j+N)*ldc + i + (M*16)], mask, result##M##N)
+#endif
+
+#define LOAD_KA_512(M, N) __m512 Aval##M = _mm512_loadu_ps(&mbuf[(mi + M)*K + k]);
+#define LOAD_KB_512(M, N) __m512 Bval##N = _mm512_loadu_ps(&B[(j + N)*ldb + k])
+#define MASK_LOAD_KA_512(M, N) __m512 Aval##M = _mm512_maskz_loadu_ps(mask, &mbuf[(mi + M)*K + k])
+#define MASK_LOAD_KB_512(M, N) __m512 Bval##N = _mm512_maskz_loadu_ps(mask, &B[(j + N)*ldb + k])
+#define REDUCE_4(rr0, rr1, rr2, rr3) \
+ __m512 r0, r1, r2, r3, t0, t1, t2, t3;\
+ r0 = _mm512_unpacklo_ps(rr0, rr1); r1 = _mm512_unpackhi_ps(rr0, rr1); \
+ r2 = _mm512_unpacklo_ps(rr2, rr3); r3 = _mm512_unpackhi_ps(rr2, rr3); \
+ t0 = _mm512_shuffle_ps(r0, r2, _MM_SHUFFLE(1, 0, 1, 0)); t1 = _mm512_shuffle_ps(r0, r2, _MM_SHUFFLE(3, 2, 3, 2)); \
+ t2 = _mm512_shuffle_ps(r1, r3, _MM_SHUFFLE(1, 0, 1, 0)); t3 = _mm512_shuffle_ps(r1, r3, _MM_SHUFFLE(3, 2, 3, 2)); \
+ r0 = _mm512_add_ps(t0, t1); r1 = _mm512_add_ps(t2, t3); t0 = _mm512_add_ps(r0, r1); \
+ __m128 s0, s1, s2, s3; \
+ s0 = _mm512_extractf32x4_ps(t0, 0); s1 = _mm512_extractf32x4_ps(t0, 1); s2 = _mm512_extractf32x4_ps(t0, 2); s3 = _mm512_extractf32x4_ps(t0, 3); \
+ s0 = _mm_maskz_add_ps(mask8, s0,
s1); s2 = _mm_maskz_add_ps(mask8, s2, s3); s0 = _mm_maskz_add_ps(mask8, s0, s2); \ + s0 = _mm_maskz_mul_ps(mask8, alpha_128, s0); +#define REDUCE_M4(N) REDUCE_4(result0##N, result1##N, result2##N, result3##N) +#define REDUCE_N4(M) REDUCE_4(result##M##0, result##M##1, result##M##2, result##M##3) +#if defined(B0) +#define STORE_REDUCE(M, N) C[(j+N)*ldc + i + M] = alpha * _mm512_reduce_add_ps(result##M##N); +#define STORE_REDUCE_M4(N) {\ + REDUCE_M4(N) \ + _mm_mask_storeu_ps(&C[(j + N)*ldc + i], mask8, s0); \ +} +#define STORE_REDUCE_N4(M) {\ + REDUCE_N4(M) \ + _mm_i32scatter_ps(&C[j*ldc + i + M], vindex_n, s0, 4); \ +} +#else +#define STORE_REDUCE(M, N) C[(j+N)*ldc + i + M] = alpha * _mm512_reduce_add_ps(result##M##N) + beta * C[(j+N)*ldc + i + M]; +#define STORE_REDUCE_M4(N) {\ + REDUCE_M4(N) \ + asm("vfmadd231ps (%1), %2, %0": "+v"(s0):"r"(&C[(j + N)*ldc + i]), "v"(beta_128)); \ + _mm_mask_storeu_ps(&C[(j + N)*ldc + i], mask8, s0); \ +} +#define STORE_REDUCE_N4(M) {\ + REDUCE_N4(M) \ + s1 = _mm_i32gather_ps(&C[j*ldc + i + M], vindex_n, 4); \ + s0 = _mm_fmadd_ps(s1, beta_128, s0); \ + _mm_i32scatter_ps(&C[j*ldc + i + M], vindex_n, s0, 4); \ +} +#endif + +#if defined(B0) +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) +#else +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) +#endif +{ + // column major + BLASLONG i, j, k; + + BLASLONG m64 = M & ~63; + BLASLONG m32 = M & ~31; + BLASLONG m16 = M & ~15; + BLASLONG m4 = M & ~3; + BLASLONG m2 = M & ~1; + + BLASLONG n6 = N - (N % 6); + BLASLONG n4 = N & ~3; + BLASLONG n2 = N & ~1; + + + __m512 alpha_512 = _mm512_broadcastss_ps(_mm_load_ss(&alpha)); +#if !defined(B0) + __m512 beta_512 = _mm512_broadcastss_ps(_mm_load_ss(&beta)); +#endif + + for (i = 0; i < m64; i += 64) { + for (j = 0; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); + + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); + + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); + } + STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); + STORE_512(0, 1); STORE_512(1, 1); STORE_512(2, 1); STORE_512(3, 1); + STORE_512(0, 2); STORE_512(1, 2); STORE_512(2, 2); STORE_512(3, 2); + STORE_512(0, 3); STORE_512(1, 3); STORE_512(2, 3); STORE_512(3, 3); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); + 
BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + } + STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); + STORE_512(0, 1); STORE_512(1, 1); STORE_512(2, 1); STORE_512(3, 1); + } + for (; j < N; j++) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); + BROADCAST_LOAD_B_512(x, 0); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + } + STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); + } + } + for (; i < m32; i += 32) { + for (j = 0; j < n6; j += 6) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); + DECLARE_RESULT_512(0, 4); DECLARE_RESULT_512(1, 4); + DECLARE_RESULT_512(0, 5); DECLARE_RESULT_512(1, 5); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + MATMUL_512(0, 4); MATMUL_512(1, 4); + MATMUL_512(0, 5); MATMUL_512(1, 5); + } + STORE_512(0, 0); STORE_512(1, 0); + STORE_512(0, 1); STORE_512(1, 1); + STORE_512(0, 2); STORE_512(1, 2); + STORE_512(0, 3); STORE_512(1, 3); + STORE_512(0, 4); STORE_512(1, 4); + STORE_512(0, 5); STORE_512(1, 5); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + } + STORE_512(0, 0); STORE_512(1, 0); + STORE_512(0, 1); STORE_512(1, 1); + } + for (; j < N; j++) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); + BROADCAST_LOAD_B_512(x, 0); + MATMUL_512(0, 0); MATMUL_512(1, 0); + } + STORE_512(0, 0); STORE_512(1, 0); + } + } + for (; i < m16; i += 16) { + for (j = 0; j < n6; j += 6) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + DECLARE_RESULT_512(0, 4); + DECLARE_RESULT_512(0, 5); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + MATMUL_512(0, 4); + MATMUL_512(0, 5); + } + STORE_512(0, 0); + STORE_512(0, 1); + STORE_512(0, 2); + STORE_512(0, 3); + STORE_512(0, 4); + STORE_512(0, 5); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + STORE_512(0, 0); + STORE_512(0, 1); + } + for (; j < N; j++) { + DECLARE_RESULT_512(0, 0); + for (k = 
0; k < K; k++) { + LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); + MATMUL_512(0, 0); + } + STORE_512(0, 0); + } + } + int mm = M - i; + if (!mm) return 0; + if (mm > 8 || K < 32) { + register __mmask16 mask asm("k1") = (1UL << mm) - 1; + for (j = 0; j < n6; j += 6) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + DECLARE_RESULT_512(0, 4); + DECLARE_RESULT_512(0, 5); + for (k = 0; k < K; k++) { + MASK_LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + MATMUL_512(0, 4); + MATMUL_512(0, 5); + } + MASK_STORE_512(0, 0); + MASK_STORE_512(0, 1); + MASK_STORE_512(0, 2); + MASK_STORE_512(0, 3); + MASK_STORE_512(0, 4); + MASK_STORE_512(0, 5); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + for (k = 0; k < K; k++) { + MASK_LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + MASK_STORE_512(0, 0); + MASK_STORE_512(0, 1); + } + for (; j < N; j++) { + DECLARE_RESULT_512(0, 0); + for (k = 0; k < K; k++) { + MASK_LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); + MATMUL_512(0, 0); + } + MASK_STORE_512(0, 0); + } + } else { + /* M => [1, 8] + * + * This kernel uses a dot-product style to compute each value C(x, y): + * C(x, y) = A(x, 0)*B(0, y) + A(x, 1)*B(1, y) + ... + A(x, K-1)*B(K-1, y) + * + * Allocate a buffer and copy the remaining rows of A into it in row-major + * order, so that memory access from 0 to K is contiguous for both A and B. + * + * Each k-loop iteration loads 16 values of k into a zmm register and FMAs them; + * at the end, the zmm accumulator is reduce-added down to the single float C(x, y). + * + * Note: performance is poor when K is small.
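 + * + * A worked illustration (hypothetical sizes, not fixed by the code): with mm = 2 + * and K = 32, the copy loop lays mbuf out as two contiguous 32-float rows, + * mbuf[0..31] = A(i+0, 0..31) and mbuf[32..63] = A(i+1, 0..31), so every k-loop + * step issues one 16-wide FMA per row before the horizontal reduction produces + * the single float C(x, y).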
+ */ + FLOAT *mbuf = (FLOAT *) malloc(sizeof(FLOAT)*mm*K); + __mmask8 mask8 = (1UL << mm) - 1; + __mmask16 mask; + BLASLONG k16 = K & ~15; + BLASLONG k8 = K & ~7; + for (k = 0; k < k8; k += 8) { + __m256 r0, r1, r2, r3, r4, r5, r6, r7; + __m256 t0, t1, t2, t3, t4, t5, t6, t7; + r0 = _mm256_maskz_loadu_ps(mask8, &A[i + lda*(0 + k)]); + r1 = _mm256_maskz_loadu_ps(mask8, &A[i + lda*(1 + k)]); + r2 = _mm256_maskz_loadu_ps(mask8, &A[i + lda*(2 + k)]); + r3 = _mm256_maskz_loadu_ps(mask8, &A[i + lda*(3 + k)]); + r4 = _mm256_maskz_loadu_ps(mask8, &A[i + lda*(4 + k)]); + r5 = _mm256_maskz_loadu_ps(mask8, &A[i + lda*(5 + k)]); + r6 = _mm256_maskz_loadu_ps(mask8, &A[i + lda*(6 + k)]); + r7 = _mm256_maskz_loadu_ps(mask8, &A[i + lda*(7 + k)]); + + t0 = _mm256_unpacklo_ps(r0, r1); + t1 = _mm256_unpackhi_ps(r0, r1); + t2 = _mm256_unpacklo_ps(r2, r3); + t3 = _mm256_unpackhi_ps(r2, r3); + t4 = _mm256_unpacklo_ps(r4, r5); + t5 = _mm256_unpackhi_ps(r4, r5); + t6 = _mm256_unpacklo_ps(r6, r7); + t7 = _mm256_unpackhi_ps(r6, r7); + + r0 = _mm256_shuffle_ps(t0,t2,_MM_SHUFFLE(1,0,1,0)); + r1 = _mm256_shuffle_ps(t0,t2,_MM_SHUFFLE(3,2,3,2)); + r2 = _mm256_shuffle_ps(t1,t3,_MM_SHUFFLE(1,0,1,0)); + r3 = _mm256_shuffle_ps(t1,t3,_MM_SHUFFLE(3,2,3,2)); + r4 = _mm256_shuffle_ps(t4,t6,_MM_SHUFFLE(1,0,1,0)); + r5 = _mm256_shuffle_ps(t4,t6,_MM_SHUFFLE(3,2,3,2)); + r6 = _mm256_shuffle_ps(t5,t7,_MM_SHUFFLE(1,0,1,0)); + r7 = _mm256_shuffle_ps(t5,t7,_MM_SHUFFLE(3,2,3,2)); + + t0 = _mm256_permute2f128_ps(r0, r4, 0x20); + t1 = _mm256_permute2f128_ps(r1, r5, 0x20); + t2 = _mm256_permute2f128_ps(r2, r6, 0x20); + t3 = _mm256_permute2f128_ps(r3, r7, 0x20); + t4 = _mm256_permute2f128_ps(r0, r4, 0x31); + t5 = _mm256_permute2f128_ps(r1, r5, 0x31); + t6 = _mm256_permute2f128_ps(r2, r6, 0x31); + t7 = _mm256_permute2f128_ps(r3, r7, 0x31); + + switch (mm) { /* intentional fall-through: store only the mm valid rows */ + case 8: _mm256_storeu_ps(&mbuf[k + 7*K], t7); + case 7: _mm256_storeu_ps(&mbuf[k + 6*K], t6); + case 6: _mm256_storeu_ps(&mbuf[k + 5*K], t5); + case 5: _mm256_storeu_ps(&mbuf[k + 4*K], t4); + case 4: _mm256_storeu_ps(&mbuf[k + 3*K], t3); + case 3: _mm256_storeu_ps(&mbuf[k + 2*K], t2); + case 2: _mm256_storeu_ps(&mbuf[k + 1*K], t1); + case 1: _mm256_storeu_ps(&mbuf[k + 0*K], t0); + } + } + for (; k < K; k++) { + for (int ii = 0; ii < mm; ii++) { + mbuf[k + ii*K] = A[i + lda*k + ii]; + } + } + int mi = 0; + mask8 = 0xff; // full mask: use masked AVX-128 instructions instead of SSE (avoids SSE/AVX transition penalties) + __m128 alpha_128 = _mm_broadcast_ss(&alpha); +#if !defined(B0) + __m128 beta_128 = _mm_broadcast_ss(&beta); +#endif + __m128i vindex_n = _mm_set_epi32(ldc*3, ldc*2, ldc, 0); + for (; i < m4; i += 4, mi += 4) { + for (j = 0; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 
3); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); + } + STORE_REDUCE_M4(0); STORE_REDUCE_M4(1); STORE_REDUCE_M4(2); STORE_REDUCE_M4(3); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + } + STORE_REDUCE_M4(0); STORE_REDUCE_M4(1); + } + for (; j < N; j += 1) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); + LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); + MASK_LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + } + STORE_REDUCE_M4(0); + } + + } + for (; i < m2; i += 2, mi += 2) { + for (j = 0; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + } + STORE_REDUCE_N4(0); STORE_REDUCE_N4(1); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); + + 
MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + } + STORE_REDUCE(0, 0); STORE_REDUCE(1, 0); + STORE_REDUCE(0, 1); STORE_REDUCE(1, 1); + + } + for (; j < N; j += 1) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); + LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); + MASK_LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + } + STORE_REDUCE(0, 0); STORE_REDUCE(1, 0); + } + } + for (; i < M; i += 1, mi += 1) { + for (j = 0; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); + + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + } + STORE_REDUCE_N4(0); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + STORE_REDUCE(0, 0); + STORE_REDUCE(0, 1); + + } + for (; j < N; j += 1) { + DECLARE_RESULT_512(0, 0); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); + LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); + MASK_LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); + } + STORE_REDUCE(0, 0); + } + } + free(mbuf); + } + return 0; +} diff --git a/kernel/x86_64/sgemm_small_kernel_nt_skylakex.c b/kernel/x86_64/sgemm_small_kernel_nt_skylakex.c new file mode 100644 index 000000000..a7d87f8c4 --- /dev/null +++ b/kernel/x86_64/sgemm_small_kernel_nt_skylakex.c @@ -0,0 +1,535 @@ +/*************************************************************************** +Copyright (c) 2021, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. 
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include <immintrin.h> +#include "common.h" +#include <stdio.h> +#include <memory.h> + +#define DECLARE_RESULT_512(M, N) __m512 result##M##N = _mm512_setzero_ps() +#define LOAD_A_512(M, N) __m512 Aval##M = _mm512_loadu_ps(&A[lda * k + i + (M*16)]) +#define MASK_LOAD_A_512(M, N) __m512 Aval##M = _mm512_maskz_loadu_ps(mask, &A[lda * k + i + (M*16)]) +#define BROADCAST_LOAD_B_512(M, N) __m512 Bval##N = _mm512_broadcastss_ps(_mm_load_ss(&B[ldb * k + j + N])) +#define MATMUL_512(M, N) result##M##N = _mm512_fmadd_ps(Aval##M, Bval##N, result##M##N) + +#define BROADCAST_LOAD_A_512(M, N) __m512 Aval##M = _mm512_broadcastss_ps(_mm_load_ss(&A[lda * k + i + M])) +#define LOAD_B_512(M, N) __m512 Bval##N = _mm512_loadu_ps(&B[ldb * k + j + (N*16)]) +#define MASK_LOAD_B_512(M, N) __m512 Bval##N = _mm512_maskz_loadu_ps(mask, &B[ldb * k + j + (N*16)]) +#if defined(B0) +#define STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ + _mm512_storeu_ps(&C[(j+N)*ldc + i + (M*16)], result##M##N) +#define MASK_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ + _mm512_mask_storeu_ps(&C[(j+N)*ldc + i + (M*16)], mask, result##M##N) +#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ + _mm512_i32scatter_ps(&C[(j + N*16)*ldc + i + M], vindex_n, result##M##N, 4); +#define MASK_SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ + _mm512_mask_i32scatter_ps(&C[(j + N*16)*ldc + i + M], mask, vindex_n, result##M##N, 4) +#else +#define STORE_512(M, N) \ + result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ + asm("vfmadd231ps (%1), %2, %0": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*16)]), "v"(beta_512)); \ + _mm512_storeu_ps(&C[(j+N)*ldc + i + (M*16)], result##M##N) +#define MASK_STORE_512(M, N) \ + result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ + asm("vfmadd231ps (%1), %2, %0 %{%3%}": "+v"(result##M##N):"r"(&C[(j+N)*ldc + i + (M*16)]), "v"(beta_512), "k"(mask)); \ + _mm512_mask_storeu_ps(&C[(j+N)*ldc + i + (M*16)], mask, result##M##N) +#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ + __m512 tmp##M##N = _mm512_i32gather_ps(vindex_n, &C[(j + N*16)*ldc + i + M], 4); \ + result##M##N = _mm512_fmadd_ps(tmp##M##N, beta_512, result##M##N); \ + _mm512_i32scatter_ps(&C[(j + N*16)*ldc + i + M], vindex_n, result##M##N, 4); +#define MASK_SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ + __m512 tmp##M##N = _mm512_mask_i32gather_ps(_mm512_setzero_ps(), mask, vindex_n, &C[(j + N*16)*ldc + i + M], 4); \ + result##M##N = _mm512_fmadd_ps(tmp##M##N, beta_512, result##M##N); \ + _mm512_mask_i32scatter_ps(&C[(j + N*16)*ldc + i + M], mask, vindex_n, 
result##M##N, 4); +#endif + +#if defined(B0) +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) +#else +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) +#endif +{ + // column major + BLASLONG i, j, k; + + BLASLONG m64 = M & ~63; + BLASLONG m32 = M & ~31; + BLASLONG m16 = M & ~15; + BLASLONG m4 = M & ~3; + BLASLONG m2 = M & ~1; + + BLASLONG n64 = N & ~63; + BLASLONG n32 = N & ~31; + BLASLONG n8 = N & ~7; + BLASLONG n6 = N - (N % 6); + BLASLONG n4 = N & ~3; + BLASLONG n2 = N & ~1; + + + __m512 alpha_512 = _mm512_broadcastss_ps(_mm_load_ss(&alpha)); +#if !defined(B0) + __m512 beta_512 = _mm512_broadcastss_ps(_mm_load_ss(&beta)); +#endif + + for (i = 0; i < m64; i += 64) { + for (j = 0; j < n6; j += 6) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); + DECLARE_RESULT_512(0, 4); DECLARE_RESULT_512(1, 4); DECLARE_RESULT_512(2, 4); DECLARE_RESULT_512(3, 4); + DECLARE_RESULT_512(0, 5); DECLARE_RESULT_512(1, 5); DECLARE_RESULT_512(2, 5); DECLARE_RESULT_512(3, 5); + + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); + + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); + BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); + MATMUL_512(0, 4); MATMUL_512(1, 4); MATMUL_512(2, 4); MATMUL_512(3, 4); + MATMUL_512(0, 5); MATMUL_512(1, 5); MATMUL_512(2, 5); MATMUL_512(3, 5); + } + STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); + STORE_512(0, 1); STORE_512(1, 1); STORE_512(2, 1); STORE_512(3, 1); + STORE_512(0, 2); STORE_512(1, 2); STORE_512(2, 2); STORE_512(3, 2); + STORE_512(0, 3); STORE_512(1, 3); STORE_512(2, 3); STORE_512(3, 3); + STORE_512(0, 4); STORE_512(1, 4); STORE_512(2, 4); STORE_512(3, 4); + STORE_512(0, 5); STORE_512(1, 5); STORE_512(2, 5); STORE_512(3, 5); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + } + STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); + STORE_512(0, 1); STORE_512(1, 1); STORE_512(2, 1); STORE_512(3, 1); + } + for (; j < N; j++) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); 
LOAD_A_512(1, x); LOAD_A_512(2, x); LOAD_A_512(3, x); + BROADCAST_LOAD_B_512(x, 0); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + } + STORE_512(0, 0); STORE_512(1, 0); STORE_512(2, 0); STORE_512(3, 0); + } + } + for (; i < m32; i += 32) { + for (j = 0; j < n8; j += 8) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); + DECLARE_RESULT_512(0, 4); DECLARE_RESULT_512(1, 4); + DECLARE_RESULT_512(0, 5); DECLARE_RESULT_512(1, 5); + DECLARE_RESULT_512(0, 6); DECLARE_RESULT_512(1, 6); + DECLARE_RESULT_512(0, 7); DECLARE_RESULT_512(1, 7); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); + BROADCAST_LOAD_B_512(x, 6); BROADCAST_LOAD_B_512(x, 7); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + MATMUL_512(0, 4); MATMUL_512(1, 4); + MATMUL_512(0, 5); MATMUL_512(1, 5); + MATMUL_512(0, 6); MATMUL_512(1, 6); + MATMUL_512(0, 7); MATMUL_512(1, 7); + } + STORE_512(0, 0); STORE_512(1, 0); + STORE_512(0, 1); STORE_512(1, 1); + STORE_512(0, 2); STORE_512(1, 2); + STORE_512(0, 3); STORE_512(1, 3); + STORE_512(0, 4); STORE_512(1, 4); + STORE_512(0, 5); STORE_512(1, 5); + STORE_512(0, 6); STORE_512(1, 6); + STORE_512(0, 7); STORE_512(1, 7); + } + for (;j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + } + STORE_512(0, 0); STORE_512(1, 0); + STORE_512(0, 1); STORE_512(1, 1); + STORE_512(0, 2); STORE_512(1, 2); + STORE_512(0, 3); STORE_512(1, 3); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + } + STORE_512(0, 0); STORE_512(1, 0); + STORE_512(0, 1); STORE_512(1, 1); + } + for (; j < N; j++) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); LOAD_A_512(1, x); + BROADCAST_LOAD_B_512(x, 0); + MATMUL_512(0, 0); MATMUL_512(1, 0); + } + STORE_512(0, 0); STORE_512(1, 0); + } + } + for (; i < m16; i += 16) { + for (j = 0; j < n8; j += 8) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + DECLARE_RESULT_512(0, 4); + DECLARE_RESULT_512(0, 5); + DECLARE_RESULT_512(0, 6); + DECLARE_RESULT_512(0, 7); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); + 
BROADCAST_LOAD_B_512(x, 6); BROADCAST_LOAD_B_512(x, 7); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + MATMUL_512(0, 4); + MATMUL_512(0, 5); + MATMUL_512(0, 6); + MATMUL_512(0, 7); + } + STORE_512(0, 0); + STORE_512(0, 1); + STORE_512(0, 2); + STORE_512(0, 3); + STORE_512(0, 4); + STORE_512(0, 5); + STORE_512(0, 6); + STORE_512(0, 7); + } + for (; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + } + STORE_512(0, 0); + STORE_512(0, 1); + STORE_512(0, 2); + STORE_512(0, 3); + } + + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + STORE_512(0, 0); + STORE_512(0, 1); + } + for (; j < N; j++) { + DECLARE_RESULT_512(0, 0); + for (k = 0; k < K; k++) { + LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); + MATMUL_512(0, 0); + } + STORE_512(0, 0); + } + } + int mm = M - i; + if (mm >= 12) { + register __mmask16 mask asm("k1") = (1UL << mm) - 1; + for (j = 0; j < n8; j += 8) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + DECLARE_RESULT_512(0, 4); + DECLARE_RESULT_512(0, 5); + DECLARE_RESULT_512(0, 6); + DECLARE_RESULT_512(0, 7); + for (k = 0; k < K; k++) { + MASK_LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + BROADCAST_LOAD_B_512(x, 4); BROADCAST_LOAD_B_512(x, 5); + BROADCAST_LOAD_B_512(x, 6); BROADCAST_LOAD_B_512(x, 7); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + MATMUL_512(0, 4); + MATMUL_512(0, 5); + MATMUL_512(0, 6); + MATMUL_512(0, 7); + } + MASK_STORE_512(0, 0); + MASK_STORE_512(0, 1); + MASK_STORE_512(0, 2); + MASK_STORE_512(0, 3); + MASK_STORE_512(0, 4); + MASK_STORE_512(0, 5); + MASK_STORE_512(0, 6); + MASK_STORE_512(0, 7); + } + for (; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + for (k = 0; k < K; k++) { + MASK_LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + BROADCAST_LOAD_B_512(x, 2); BROADCAST_LOAD_B_512(x, 3); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + } + MASK_STORE_512(0, 0); + MASK_STORE_512(0, 1); + MASK_STORE_512(0, 2); + MASK_STORE_512(0, 3); + } + + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + for (k = 0; k < K; k++) { + MASK_LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); BROADCAST_LOAD_B_512(x, 1); + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + MASK_STORE_512(0, 0); + MASK_STORE_512(0, 1); + } + for (; j < N; j++) { + DECLARE_RESULT_512(0, 0); + for (k = 0; k < K; k++) { + MASK_LOAD_A_512(0, x); + BROADCAST_LOAD_B_512(x, 0); + MATMUL_512(0, 0); + } + MASK_STORE_512(0, 0); + } + } else if (mm > 0) { + int index_n[16]; + for (int ii = 0; ii < 16; ii++) { + index_n[ii] = ii * ldc; + } + __m512i vindex_n = _mm512_loadu_si512(index_n); + for (; i < m4; i += 4) { + for (j = 0; j < n64; j += 64) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); 
DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); + LOAD_B_512(x, 0); + LOAD_B_512(x, 1); + LOAD_B_512(x, 2); + LOAD_B_512(x, 3); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); + } + SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0); SCATTER_STORE_512(2, 0); SCATTER_STORE_512(3, 0); + SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1); SCATTER_STORE_512(2, 1); SCATTER_STORE_512(3, 1); + SCATTER_STORE_512(0, 2); SCATTER_STORE_512(1, 2); SCATTER_STORE_512(2, 2); SCATTER_STORE_512(3, 2); + SCATTER_STORE_512(0, 3); SCATTER_STORE_512(1, 3); SCATTER_STORE_512(2, 3); SCATTER_STORE_512(3, 3); + } + for (; j < n32; j += 32) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); + LOAD_B_512(x, 0); + LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + } + SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0); SCATTER_STORE_512(2, 0); SCATTER_STORE_512(3, 0); + SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1); SCATTER_STORE_512(2, 1); SCATTER_STORE_512(3, 1); + } + __mmask16 mask = 0xffff; + for (; j < N; j += 16) { + int remains = N - j; + if (remains < 16) mask = (1UL << remains) - 1; + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); + MASK_LOAD_B_512(x, 0); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + } + MASK_SCATTER_STORE_512(0, 0); MASK_SCATTER_STORE_512(1, 0); MASK_SCATTER_STORE_512(2, 0); MASK_SCATTER_STORE_512(3, 0); + } + } + for (; i < m2; i += 2) { + for (j = 0; j < n64; j += 64) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); + LOAD_B_512(x, 0); + LOAD_B_512(x, 1); + LOAD_B_512(x, 2); + LOAD_B_512(x, 3); + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + } + SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0); + SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1); + SCATTER_STORE_512(0, 2); SCATTER_STORE_512(1, 2); + SCATTER_STORE_512(0, 3); SCATTER_STORE_512(1, 3); + } + for (; j < n32; j += 32) { + 
DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); + LOAD_B_512(x, 0); + LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + } + SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0); + SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1); + } + __mmask16 mask = 0xffff; + for (; j < N; j += 16) { + int remains = N - j; + if (remains < 16) mask = (1UL << remains) - 1; + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); + MASK_LOAD_B_512(x, 0); + MATMUL_512(0, 0); MATMUL_512(1, 0); + } + MASK_SCATTER_STORE_512(0, 0); MASK_SCATTER_STORE_512(1, 0); + } + } + for (; i < M; i += 1) { + for (j = 0; j < n64; j += 64) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); + LOAD_B_512(x, 0); + LOAD_B_512(x, 1); + LOAD_B_512(x, 2); + LOAD_B_512(x, 3); + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + } + SCATTER_STORE_512(0, 0); + SCATTER_STORE_512(0, 1); + SCATTER_STORE_512(0, 2); + SCATTER_STORE_512(0, 3); + } + for (; j < n32; j += 32) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); + LOAD_B_512(x, 0); + LOAD_B_512(x, 1); + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + SCATTER_STORE_512(0, 0); + SCATTER_STORE_512(0, 1); + } + __mmask16 mask = 0xffff; + for (; j < N; j += 16) { + int remains = N - j; + if (remains < 16) mask = (1UL << remains) - 1; + DECLARE_RESULT_512(0, 0); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); + MASK_LOAD_B_512(x, 0); + MATMUL_512(0, 0); + } + MASK_SCATTER_STORE_512(0, 0); + } + } + } + return 0; +} diff --git a/kernel/x86_64/sgemm_small_kernel_permit_skylakex.c b/kernel/x86_64/sgemm_small_kernel_permit_skylakex.c new file mode 100644 index 000000000..cbf2374bd --- /dev/null +++ b/kernel/x86_64/sgemm_small_kernel_permit_skylakex.c @@ -0,0 +1,53 @@ +/*************************************************************************** +Copyright (c) 2021, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. 
IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include "common.h" + +int CNAME(int transa, int transb, BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, FLOAT beta) +{ + double MNK = (double) M * (double) N * (double) K; + if (MNK > 100.0*100.0*100.0) // fall back to the regular kernels for larger matrices + return 0; + // tuning for the transposed-A kernels + if (transa) { + if (transb) { + /* The TT kernel does not perform well when: + * 1. K is too small. + */ + if (K < 4) return 0; + } else { + /* The TN kernel does not perform well when: + * 1. the C matrix is too big, + * 2. K is too small. + */ + if (M * N > 1200 || K < 32) + return 0; + } + } + + return 1; +} diff --git a/kernel/x86_64/sgemm_small_kernel_tn_skylakex.c b/kernel/x86_64/sgemm_small_kernel_tn_skylakex.c new file mode 100644 index 000000000..5a9a4ea32 --- /dev/null +++ b/kernel/x86_64/sgemm_small_kernel_tn_skylakex.c @@ -0,0 +1,316 @@ +/*************************************************************************** +Copyright (c) 2021, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*****************************************************************************/ + +#include <immintrin.h> +#include "common.h" +#include <stdio.h> +#include <memory.h> + +#define DECLARE_RESULT_512(M, N) __m512 result##M##N = _mm512_setzero_ps() +#define MATMUL_512(M, N) result##M##N = _mm512_fmadd_ps(Aval##M, Bval##N, result##M##N) + +#define LOAD_KA_512(M, N) __m512 Aval##M = _mm512_loadu_ps(&A[(i + M)*lda + k]); +#define LOAD_KB_512(M, N) __m512 Bval##N = _mm512_loadu_ps(&B[(j + N)*ldb + k]) +#define MASK_LOAD_KA_512(M, N) __m512 Aval##M = _mm512_maskz_loadu_ps(mask, &A[(i + M)*lda + k]) +#define MASK_LOAD_KB_512(M, N) __m512 Bval##N = _mm512_maskz_loadu_ps(mask, &B[(j + N)*ldb + k]) + +#define REDUCE_4(rr0, rr1, rr2, rr3) \ + __m512 r0, r1, r2, r3, t0, t1, t2, t3;\ + r0 = _mm512_unpacklo_ps(rr0, rr1); r1 = _mm512_unpackhi_ps(rr0, rr1); \ + r2 = _mm512_unpacklo_ps(rr2, rr3); r3 = _mm512_unpackhi_ps(rr2, rr3); \ + t0 = _mm512_shuffle_ps(r0, r2, _MM_SHUFFLE(1, 0, 1, 0)); t1 = _mm512_shuffle_ps(r0, r2, _MM_SHUFFLE(3, 2, 3, 2)); \ + t2 = _mm512_shuffle_ps(r1, r3, _MM_SHUFFLE(1, 0, 1, 0)); t3 = _mm512_shuffle_ps(r1, r3, _MM_SHUFFLE(3, 2, 3, 2)); \ + r0 = _mm512_add_ps(t0, t1); r1 = _mm512_add_ps(t2, t3); t0 = _mm512_add_ps(r0, r1); \ + __m128 s0, s1, s2, s3; \ + s0 = _mm512_extractf32x4_ps(t0, 0); s1 = _mm512_extractf32x4_ps(t0, 1); s2 = _mm512_extractf32x4_ps(t0, 2); s3 = _mm512_extractf32x4_ps(t0, 3); \ + s0 = _mm_maskz_add_ps(mask8, s0, s1); s2 = _mm_maskz_add_ps(mask8, s2, s3); s0 = _mm_maskz_add_ps(mask8, s0, s2); \ + s0 = _mm_maskz_mul_ps(mask8, alpha_128, s0); + +#define REDUCE_M4(N) REDUCE_4(result0##N, result1##N, result2##N, result3##N) +#define REDUCE_N4(M) REDUCE_4(result##M##0, result##M##1, result##M##2, result##M##3) + +#if defined(B0) +#define STORE_REDUCE(M, N) C[(j+N)*ldc + i + M] = alpha * _mm512_reduce_add_ps(result##M##N) +#define STORE_M4(N, s0) _mm_mask_storeu_ps(&C[(j + N)*ldc + i], mask8, s0); +#define STORE_N4(M, s0) _mm_i32scatter_ps(&C[j*ldc + i + M], vindex_n, s0, 4); +#else +#define STORE_REDUCE(M, N) C[(j+N)*ldc + i + M] = alpha * _mm512_reduce_add_ps(result##M##N) + beta * C[(j+N)*ldc + i + M] +#define STORE_M4(N, s0) \ + asm("vfmadd231ps (%1), %2, %0": "+v"(s0):"r"(&C[(j + N)*ldc + i]), "v"(beta_128)); \ + _mm_mask_storeu_ps(&C[(j + N)*ldc + i], mask8, s0); + +#define STORE_N4(M, s0) \ + s0 = _mm_fmadd_ps(_mm_i32gather_ps(&C[j*ldc + i + M], vindex_n, 4), beta_128, s0); \ + _mm_i32scatter_ps(&C[j*ldc + i + M], vindex_n, s0, 4); +#endif +#define STORE_REDUCE_M4(N) {\ + REDUCE_M4(N) \ + STORE_M4(N, s0) \ +} +#define STORE_REDUCE_N4(M) {\ + REDUCE_N4(M) \ + STORE_N4(M, s0) \ +} + + +#if defined(B0) +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) +#else +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) +#endif +{ + // column major + BLASLONG i, j, k; + + BLASLONG m4 = M & ~3; + BLASLONG m2 = M & ~1; + + BLASLONG n4 = N & ~3; + BLASLONG n2 = N & ~1; + + BLASLONG k16 = K & ~15; + + __mmask16 mask; + __mmask8 mask8 = 0xff; // full mask: use masked AVX-128 instructions instead of SSE (avoids SSE/AVX transition penalties) + + __m128i vindex_n = _mm_set_epi32(ldc*3, ldc*2, ldc, 0); + __m128 alpha_128 = _mm_broadcast_ss(&alpha); +#if !defined(B0) + __m128 beta_128 = _mm_broadcast_ss(&beta); +#endif + for (i = 0; i < m4; i += 4) { + for (j = 0; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + 
DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3); + } + STORE_REDUCE_M4(0); STORE_REDUCE_M4(1); STORE_REDUCE_M4(2); STORE_REDUCE_M4(3); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + } + STORE_REDUCE_M4(0); STORE_REDUCE_M4(1); + } + for (; j < N; j += 1) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x); + LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x); + MASK_LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + } + STORE_REDUCE_M4(0); + } + + } + for (; i < m2; i += 2) { + for (j = 0; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); + DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + 
MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + MATMUL_512(0, 2); MATMUL_512(1, 2); + MATMUL_512(0, 3); MATMUL_512(1, 3); + } + STORE_REDUCE_N4(0); STORE_REDUCE_N4(1); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); + } + STORE_REDUCE(0, 0); STORE_REDUCE(1, 0); + STORE_REDUCE(0, 1); STORE_REDUCE(1, 1); + + } + for (; j < N; j += 1) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); LOAD_KA_512(1, x); + LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); + MASK_LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); MATMUL_512(1, 0); + } + STORE_REDUCE(0, 0); STORE_REDUCE(1, 0); + } + } + for (; i < M; i += 1) { + for (j = 0; j < n4; j += 4) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + DECLARE_RESULT_512(0, 2); + DECLARE_RESULT_512(0, 3); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3); + + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + MATMUL_512(0, 2); + MATMUL_512(0, 3); + } + STORE_REDUCE_N4(0); + } + for (; j < n2; j += 2) { + DECLARE_RESULT_512(0, 0); + DECLARE_RESULT_512(0, 1); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); + LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); + MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); + + MATMUL_512(0, 0); + MATMUL_512(0, 1); + } + STORE_REDUCE(0, 0); + STORE_REDUCE(0, 1); + + } + for (; j < N; j += 1) { + DECLARE_RESULT_512(0, 0); + for (k = 0; k < k16; k += 16) { + LOAD_KA_512(0, x); + LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); + } + int remains = K - k; + if (remains) { + mask = (1UL << remains) - 1; + MASK_LOAD_KA_512(0, x); + MASK_LOAD_KB_512(x, 0); + + MATMUL_512(0, 0); + } + STORE_REDUCE(0, 0); + } + } + return 0; +} diff --git a/kernel/x86_64/sgemm_small_kernel_tt_skylakex.c b/kernel/x86_64/sgemm_small_kernel_tt_skylakex.c new file mode 100644 index 000000000..023f58746 --- /dev/null +++ b/kernel/x86_64/sgemm_small_kernel_tt_skylakex.c @@ -0,0 +1,414 @@ +/*************************************************************************** +Copyright (c) 2021, The OpenBLAS Project +All 
rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#include <immintrin.h> +#include "common.h" +#include <stdio.h> + +#define DECLARE_RESULT_512(M, N) __m512 result##M##N = _mm512_setzero_ps() +#define BROADCAST_LOAD_A_512(M, N) __m512 Aval##M = _mm512_broadcastss_ps(_mm_load_ss(&A[k + lda * (i+M)])) +#define LOAD_B_512(M,N) __m512 Bval##N = _mm512_loadu_ps(&B[ldb * k + j + (N*16)]) +#define MASK_LOAD_B_512(M, N) __m512 Bval##N = _mm512_maskz_loadu_ps(mask, &B[ldb * k + j + (N*16)]) +#define MATMUL_512(M, N) result##M##N = _mm512_fmadd_ps(Aval##M, Bval##N, result##M##N) + +#if defined(B0) +#define STORE_8xy(v, N, x, y) _mm256_storeu_ps(&C[(j + N*16 + x + y*8)*ldc + i], v) +#define STORE_4xy(v, N, x, y) _mm_mask_storeu_ps(&C[(j + N*16 + x + y*4)*ldc + i], mask8, v) +#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ + _mm512_i32scatter_ps(&C[(j + N*16)*ldc + i + M], vindex_n, result##M##N, 4); +#define MASK_SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ + _mm512_mask_i32scatter_ps(&C[(j + N*16)*ldc + i + M], mask, vindex_n, result##M##N, 4); +#else +#define STORE_8xy(v, N, x, y) \ + asm("vfmadd231ps (%1), %2, %0": "+v"(v): "r"(&C[(j + N*16 + x + y*8)*ldc + i]), "v"(beta_256)); \ + _mm256_storeu_ps(&C[(j + N*16 + x + y*8)*ldc + i], v) +#define STORE_4xy(v, N, x, y) \ + asm("vfmadd231ps (%1), %2, %0": "+v"(v): "r"(&C[(j + N*16 + x + y*4)*ldc + i]), "v"(beta_128)); \ + _mm_mask_storeu_ps(&C[(j + N*16 + x + y*4)*ldc + i], mask8, v) +#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ + __m512 tmp##M##N = _mm512_i32gather_ps(vindex_n, &C[(j + N*16)*ldc + i + M], 4); \ + result##M##N = _mm512_fmadd_ps(tmp##M##N, beta_512, result##M##N); \ + _mm512_i32scatter_ps(&C[(j + N*16)*ldc + i + M], vindex_n, result##M##N, 4); +#define MASK_SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \ + __m512 tmp##M##N = _mm512_mask_i32gather_ps(_mm512_setzero_ps(), mask, vindex_n, &C[(j + N*16)*ldc + i + M], 4); \ + result##M##N = 
_mm512_fmadd_ps(tmp##M##N, beta_512, result##M##N); \ + _mm512_mask_i32scatter_ps(&C[(j + N*16)*ldc + i + M], mask, vindex_n, result##M##N, 4); +#endif + +#define REORDER_8x16(r0, r1, r2, r3, r4, r5, r6, r7) \ + __m512 t0, t1, t2, t3, t4, t5, t6, t7, v; \ + t0 = _mm512_unpacklo_ps(r0, r1); \ + t1 = _mm512_unpackhi_ps(r0, r1); \ + t2 = _mm512_unpacklo_ps(r2, r3); \ + t3 = _mm512_unpackhi_ps(r2, r3); \ + t4 = _mm512_unpacklo_ps(r4, r5); \ + t5 = _mm512_unpackhi_ps(r4, r5); \ + t6 = _mm512_unpacklo_ps(r6, r7); \ + t7 = _mm512_unpackhi_ps(r6, r7); \ + v = _mm512_shuffle_ps(t0, t2, 0x4E); \ + r0 = _mm512_mask_blend_ps(kc, t0, v); \ + r1 = _mm512_mask_blend_ps(k3, t2, v); \ + v = _mm512_shuffle_ps(t1, t3, 0x4E); \ + r2 = _mm512_mask_blend_ps(kc, t1, v); \ + r3 = _mm512_mask_blend_ps(k3, t3, v); \ + v = _mm512_shuffle_ps(t4, t6, 0x4E); \ + r4 = _mm512_mask_blend_ps(kc, t4, v); \ + r5 = _mm512_mask_blend_ps(k3, t6, v); \ + v = _mm512_shuffle_ps(t5, t7, 0x4E); \ + r6 = _mm512_mask_blend_ps(kc, t5, v); \ + r7 = _mm512_mask_blend_ps(k3, t7, v); \ + t0 = _mm512_permutex2var_ps(r0, idx_lo, r4); \ + t1 = _mm512_permutex2var_ps(r1, idx_lo, r5); \ + t2 = _mm512_permutex2var_ps(r2, idx_lo, r6); \ + t3 = _mm512_permutex2var_ps(r3, idx_lo, r7); \ + t4 = _mm512_permutex2var_ps(r0, idx_hi, r4); \ + t5 = _mm512_permutex2var_ps(r1, idx_hi, r5); \ + t6 = _mm512_permutex2var_ps(r2, idx_hi, r6); \ + t7 = _mm512_permutex2var_ps(r3, idx_hi, r7); \ + t0 = _mm512_mul_ps(t0, alpha_512); \ + t1 = _mm512_mul_ps(t1, alpha_512); \ + t2 = _mm512_mul_ps(t2, alpha_512); \ + t3 = _mm512_mul_ps(t3, alpha_512); \ + t4 = _mm512_mul_ps(t4, alpha_512); \ + t5 = _mm512_mul_ps(t5, alpha_512); \ + t6 = _mm512_mul_ps(t6, alpha_512); \ + t7 = _mm512_mul_ps(t7, alpha_512); + +#define SAVE_8(N, x, y) {\ + __m256 v8 = _mm512_extractf32x8_ps(t##x, y); \ + STORE_8xy(v8, N, x, y); \ +} + +#define REORDER_STORE_8x16(N) {\ + REORDER_8x16(result0##N, result1##N, result2##N, result3##N, result4##N, result5##N, result6##N, result7##N); \ + SAVE_8(N, 0, 0); SAVE_8(N, 1, 0); SAVE_8(N, 2, 0); SAVE_8(N, 3, 0); SAVE_8(N, 4, 0); SAVE_8(N, 5, 0); SAVE_8(N, 6, 0); SAVE_8(N, 7, 0); \ + SAVE_8(N, 0, 1); SAVE_8(N, 1, 1); SAVE_8(N, 2, 1); SAVE_8(N, 3, 1); SAVE_8(N, 4, 1); SAVE_8(N, 5, 1); SAVE_8(N, 6, 1); SAVE_8(N, 7, 1); \ +} + +#define MASK_SAVE_8() \ + switch (nn) { \ + case 16: SAVE_8(0, 7, 1); \ + case 15: SAVE_8(0, 6, 1); \ + case 14: SAVE_8(0, 5, 1); \ + case 13: SAVE_8(0, 4, 1); \ + case 12: SAVE_8(0, 3, 1); \ + case 11: SAVE_8(0, 2, 1); \ + case 10: SAVE_8(0, 1, 1); \ + case 9: SAVE_8(0, 0, 1); \ + case 8: SAVE_8(0, 7, 0); \ + case 7: SAVE_8(0, 6, 0); \ + case 6: SAVE_8(0, 5, 0); \ + case 5: SAVE_8(0, 4, 0); \ + case 4: SAVE_8(0, 3, 0); \ + case 3: SAVE_8(0, 2, 0); \ + case 2: SAVE_8(0, 1, 0); \ + case 1: SAVE_8(0, 0, 0); \ + } + +#define MASK_REORDER_STORE_8x16(N) {\ + REORDER_8x16(result0##N, result1##N, result2##N, result3##N, result4##N, result5##N, result6##N, result7##N); \ + MASK_SAVE_8(); \ +} + +#define REORDER_4x16(r0, r1, r2, r3) \ + __m512 t0, t1, t2, t3, v; \ + t0 = _mm512_unpacklo_ps(r0, r1); \ + t1 = _mm512_unpackhi_ps(r0, r1); \ + t2 = _mm512_unpacklo_ps(r2, r3); \ + t3 = _mm512_unpackhi_ps(r2, r3); \ + v = _mm512_shuffle_ps(t0, t2, 0x4E); \ + r0 = _mm512_mask_blend_ps(kc, t0, v); \ + r1 = _mm512_mask_blend_ps(k3, t2, v); \ + v = _mm512_shuffle_ps(t1, t3, 0x4E); \ + r2 = _mm512_mask_blend_ps(kc, t1, v); \ + r3 = _mm512_mask_blend_ps(k3, t3, v); \ + t0 = _mm512_mul_ps(r0, alpha_512); \ + t1 = _mm512_mul_ps(r1, alpha_512); \ + t2 = 
_mm512_mul_ps(r2, alpha_512); \ + t3 = _mm512_mul_ps(r3, alpha_512); + +#define SAVE_4(N, x, y) {\ + __m128 v4 = _mm512_extractf32x4_ps(t##x, y); \ + STORE_4xy(v4, N, x, y); \ +} + +#define REORDER_STORE_4x16(N) {\ + REORDER_4x16(result0##N, result1##N, result2##N, result3##N); \ + SAVE_4(N, 0, 0); SAVE_4(N, 1, 0); SAVE_4(N, 2, 0); SAVE_4(N, 3, 0); \ + SAVE_4(N, 0, 1); SAVE_4(N, 1, 1); SAVE_4(N, 2, 1); SAVE_4(N, 3, 1); \ + SAVE_4(N, 0, 2); SAVE_4(N, 1, 2); SAVE_4(N, 2, 2); SAVE_4(N, 3, 2); \ + SAVE_4(N, 0, 3); SAVE_4(N, 1, 3); SAVE_4(N, 2, 3); SAVE_4(N, 3, 3); \ +} + +#define MASK_SAVE_4() \ + switch (nn) { \ + case 16: SAVE_4(0, 3, 3); \ + case 15: SAVE_4(0, 2, 3); \ + case 14: SAVE_4(0, 1, 3); \ + case 13: SAVE_4(0, 0, 3); \ + case 12: SAVE_4(0, 3, 2); \ + case 11: SAVE_4(0, 2, 2); \ + case 10: SAVE_4(0, 1, 2); \ + case 9: SAVE_4(0, 0, 2); \ + case 8: SAVE_4(0, 3, 1); \ + case 7: SAVE_4(0, 2, 1); \ + case 6: SAVE_4(0, 1, 1); \ + case 5: SAVE_4(0, 0, 1); \ + case 4: SAVE_4(0, 3, 0); \ + case 3: SAVE_4(0, 2, 0); \ + case 2: SAVE_4(0, 1, 0); \ + case 1: SAVE_4(0, 0, 0); \ + } + +#define MASK_REORDER_STORE_4x16(N) {\ + REORDER_4x16(result0##N, result1##N, result2##N, result3##N); \ + MASK_SAVE_4(); \ +} + + +#if defined(B0) +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc) +#else +int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc) +#endif +{ + // column major + BLASLONG i, j, k; + + BLASLONG m8 = M & ~7; + BLASLONG m4 = M & ~3; + BLASLONG m2 = M & ~1; + + BLASLONG n64 = N & ~63; + BLASLONG n32 = N & ~31; + + __m512 alpha_512 = _mm512_broadcastss_ps(_mm_load_ss(&alpha)); +#if !defined(B0) + __m256 beta_256 = _mm256_broadcastss_ps(_mm_load_ss(&beta)); + __m128 beta_128 = _mm_broadcastss_ps(_mm_load_ss(&beta)); +#endif + int permute_table[] = { + 0x0, 0x1, 0x2, 0x3, 0x10, 0x11, 0x12, 0x13, 0x8, 0x9, 0xa, 0xb, 0x18, 0x19, 0x1a, 0x1b, + 0x4, 0x5, 0x6, 0x7, 0x14, 0x15, 0x16, 0x17, 0xc, 0xd, 0xe, 0xf, 0x1c, 0x1d, 0x1e, 0x1f, + }; + __m512i idx_lo = _mm512_loadu_si512(permute_table); + __m512i idx_hi = _mm512_loadu_si512(permute_table + 16); + __mmask16 kc = 0xcccc; + __mmask16 k3 = 0x3333; + __mmask8 mask8 = 0xff; // force use AVX128 instead of SSE + + for (i = 0; i < m8; i += 8) { + for (j = 0; j < n32; j += 32) { + DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0); + DECLARE_RESULT_512(4, 0); DECLARE_RESULT_512(5, 0); DECLARE_RESULT_512(6, 0); DECLARE_RESULT_512(7, 0); + + DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1); + DECLARE_RESULT_512(4, 1); DECLARE_RESULT_512(5, 1); DECLARE_RESULT_512(6, 1); DECLARE_RESULT_512(7, 1); + for (k = 0; k < K; k++) { + BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x); + BROADCAST_LOAD_A_512(4, x); BROADCAST_LOAD_A_512(5, x); BROADCAST_LOAD_A_512(6, x); BROADCAST_LOAD_A_512(7, x); + LOAD_B_512(x, 0); LOAD_B_512(x, 1); + MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0); + MATMUL_512(4, 0); MATMUL_512(5, 0); MATMUL_512(6, 0); MATMUL_512(7, 0); + MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1); + MATMUL_512(4, 1); MATMUL_512(5, 1); MATMUL_512(6, 1); MATMUL_512(7, 1); + } + REORDER_STORE_8x16(0); + REORDER_STORE_8x16(1); + } + __mmask16 mask = 0xffff; + int nn = 16; + for (; j < 
+	for (; i < m4; i += 4) {
+		for (j = 0; j < n64; j += 64) {
+			DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0);
+			DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1);
+			DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2);
+			DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3);
+			for (k = 0; k < K; k++) {
+				BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x);
+				LOAD_B_512(x, 0); LOAD_B_512(x, 1); LOAD_B_512(x, 2); LOAD_B_512(x, 3);
+				MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0);
+				MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1);
+				MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2);
+				MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3);
+			}
+			REORDER_STORE_4x16(0);
+			REORDER_STORE_4x16(1);
+			REORDER_STORE_4x16(2);
+			REORDER_STORE_4x16(3);
+		}
+		for (; j < n32; j += 32) {
+			DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0);
+			DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1);
+			for (k = 0; k < K; k++) {
+				BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x);
+				LOAD_B_512(x, 0); LOAD_B_512(x, 1);
+				MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0);
+				MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1);
+			}
+			REORDER_STORE_4x16(0);
+			REORDER_STORE_4x16(1);
+		}
+		__mmask16 mask = 0xffff;
+		int nn = 16;
+		for (; j < N; j += 16) {
+			if (N - j < 16) {
+				nn = N - j;
+				mask = (1UL << nn) - 1;
+			}
+			DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0);
+			for (k = 0; k < K; k++) {
+				BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x);
+				MASK_LOAD_B_512(x, 0);
+				MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0);
+			}
+			MASK_REORDER_STORE_4x16(0);
+		}
+	}
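NOTE (reviewer illustration, not part of the patch): the REORDER_4x16 / REORDER_8x16 bodies are defined earlier in the file and are not visible in this hunk, but the permute_table loaded into idx_lo and idx_hi near the top of the function is the standard two-source shuffle layout: in _mm512_permutex2var_ps, index bit 4 (0x10) selects the second source register, so idx_lo interleaves 4-float groups from two accumulators. A hedged sketch of that building block:

    #include <immintrin.h>

    /* Interleaves 4-float groups of a and b the way idx_lo above is laid
     * out: result lanes = a0..a3, b0..b3, a8..a11, b8..b11.  Exactly how
     * the kernel's REORDER_* macros chain such shuffles is not shown in
     * this hunk. */
    static __m512 interleave_lo(__m512 a, __m512 b) {
        const int tbl[16] = { 0x0, 0x1, 0x2, 0x3, 0x10, 0x11, 0x12, 0x13,
                              0x8, 0x9, 0xa, 0xb, 0x18, 0x19, 0x1a, 0x1b };
        __m512i idx_lo = _mm512_loadu_si512(tbl);
        return _mm512_permutex2var_ps(a, idx_lo, b);
    }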
+	if (i < M) {
+		int index_n[16];
+		for (int ii = 0; ii < 16; ii++) {
+			index_n[ii] = ii * ldc;
+		}
+		__m512i vindex_n = _mm512_loadu_si512(index_n);
+#if !defined(B0)
+		__m512 beta_512 = _mm512_broadcastss_ps(_mm_load_ss(&beta));
+#endif
+		for (; i < m2; i += 2) {
+			for (j = 0; j < n64; j += 64) {
+				DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0);
+				DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1);
+				DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2);
+				DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3);
+				for (k = 0; k < K; k++) {
+					BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x);
+					LOAD_B_512(x, 0); LOAD_B_512(x, 1); LOAD_B_512(x, 2); LOAD_B_512(x, 3);
+					MATMUL_512(0, 0); MATMUL_512(1, 0);
+					MATMUL_512(0, 1); MATMUL_512(1, 1);
+					MATMUL_512(0, 2); MATMUL_512(1, 2);
+					MATMUL_512(0, 3); MATMUL_512(1, 3);
+				}
+				SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0);
+				SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1);
+				SCATTER_STORE_512(0, 2); SCATTER_STORE_512(1, 2);
+				SCATTER_STORE_512(0, 3); SCATTER_STORE_512(1, 3);
+			}
+			for (; j < n32; j += 32) {
+				DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0);
+				DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1);
+				for (k = 0; k < K; k++) {
+					BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x);
+					LOAD_B_512(x, 0); LOAD_B_512(x, 1);
+					MATMUL_512(0, 0); MATMUL_512(1, 0);
+					MATMUL_512(0, 1); MATMUL_512(1, 1);
+				}
+				SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0);
+				SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1);
+			}
+			__mmask16 mask = 0xffff;
+			int nn = 16;
+			for (; j < N; j += 16) {
+				if (N - j < 16) {
+					nn = N - j;
+					mask = (1UL << nn) - 1;
+				}
+				DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0);
+				for (k = 0; k < K; k++) {
+					BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x);
+					MASK_LOAD_B_512(x, 0);
+					MATMUL_512(0, 0); MATMUL_512(1, 0);
+				}
+				MASK_SCATTER_STORE_512(0, 0); MASK_SCATTER_STORE_512(1, 0);
+			}
+		}
+		for (; i < M; i += 1) {
+			for (j = 0; j < n64; j += 64) {
+				DECLARE_RESULT_512(0, 0);
+				DECLARE_RESULT_512(0, 1);
+				DECLARE_RESULT_512(0, 2);
+				DECLARE_RESULT_512(0, 3);
+				for (k = 0; k < K; k++) {
+					BROADCAST_LOAD_A_512(0, x);
+					LOAD_B_512(x, 0); LOAD_B_512(x, 1); LOAD_B_512(x, 2); LOAD_B_512(x, 3);
+					MATMUL_512(0, 0);
+					MATMUL_512(0, 1);
+					MATMUL_512(0, 2);
+					MATMUL_512(0, 3);
+				}
+				SCATTER_STORE_512(0, 0);
+				SCATTER_STORE_512(0, 1);
+				SCATTER_STORE_512(0, 2);
+				SCATTER_STORE_512(0, 3);
+			}
+			for (; j < n32; j += 32) {
+				DECLARE_RESULT_512(0, 0);
+				DECLARE_RESULT_512(0, 1);
+				for (k = 0; k < K; k++) {
+					BROADCAST_LOAD_A_512(0, x);
+					LOAD_B_512(x, 0); LOAD_B_512(x, 1);
+					MATMUL_512(0, 0);
+					MATMUL_512(0, 1);
+				}
+				SCATTER_STORE_512(0, 0);
+				SCATTER_STORE_512(0, 1);
+			}
+			__mmask16 mask = 0xffff;
+			int nn = 16;
+			for (; j < N; j += 16) {
+				if (N - j < 16) {
+					nn = N - j;
+					mask = (1UL << nn) - 1;
+				}
+				DECLARE_RESULT_512(0, 0);
+				for (k = 0; k < K; k++) {
+					BROADCAST_LOAD_A_512(0, x);
+					MASK_LOAD_B_512(x, 0);
+					MATMUL_512(0, 0);
+				}
+				MASK_SCATTER_STORE_512(0, 0);
+			}
+		}
+	}
+	return 0;
+}
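NOTE (reviewer illustration, not part of the patch): the SCATTER_STORE_512 / MASK_SCATTER_STORE_512 paths above handle the last one or two rows, where a 16-float vector holds one element of a single row from each of 16 consecutive columns. Since C is column-major, column n of row i lives at C[i + n*ldc], which is what the vindex_n vector built from index_n encodes. A minimal sketch of such a scatter, under those assumptions:

    #include <immintrin.h>

    /* Stores v[n] to c[n*ldc] for n = 0..15, i.e. one row spread across 16
     * column-major columns; 'c' points at C[i + j*ldc], scale 4 = sizeof(float). */
    static void scatter_row16(float *c, long ldc, __m512 v) {
        int idx[16];
        for (int n = 0; n < 16; n++) idx[n] = (int)(n * ldc);
        __m512i vindex = _mm512_loadu_si512(idx);
        _mm512_i32scatter_ps(c, vindex, v, 4);
    }

The masked tail would pass the same index vector to _mm512_mask_i32scatter_ps so that only the valid columns are written.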
diff --git a/kernel/x86_64/sgemv_n_4.c b/kernel/x86_64/sgemv_n_4.c
index 06de28d97..e0778006f 100644
--- a/kernel/x86_64/sgemv_n_4.c
+++ b/kernel/x86_64/sgemv_n_4.c
@@ -115,6 +115,8 @@ static void sgemv_kernel_4x4(BLASLONG n, FLOAT **ap, FLOAT *xo, FLOAT *y, FLOAT
 
 #endif
 
+#ifndef HAVE_SGEMV_N_SKYLAKE_KERNEL
+
 #ifndef HAVE_KERNEL_4x2
 
 static void sgemv_kernel_4x2( BLASLONG n, FLOAT **ap, FLOAT *x, FLOAT *y, FLOAT *alpha) __attribute__ ((noinline));
@@ -170,6 +172,7 @@ static void sgemv_kernel_4x2( BLASLONG n, FLOAT **ap, FLOAT *x, FLOAT *y, FLOAT
 
 }
 
+#endif
 #endif
 
 #ifndef HAVE_KERNEL_4x1
@@ -302,9 +305,6 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
 	FLOAT * xbuffer_align = x;
 	FLOAT * ybuffer_align = y;
 
-	FLOAT * xbuffer = NULL;
-	FLOAT * ybuffer = NULL;
-
 	if (inc_x != 1) {
 		xbuffer_align = buffer;
 		for(BLASLONG i=0; i
[...]
-	unsigned short tail_mask_value = (((unsigned int)0xffff) >> (16-((m-tag_m_8x)*2)&15));
+	unsigned short tail_mask_value = (((unsigned int)0xffff) >> (16-(((m-tag_m_8x)*2)&15)));
 	__mmask16 a_mask = *((__mmask16*) &tail_mask_value);
 	unsigned char y_mask_value = (((unsigned char)0xff) >> (8-(m-tag_m_8x)));
 	__mmask8 y_mask = *((__mmask8*) &y_mask_value);
-	m0 = _mm512_maskz_loadu_ps(a_mask, &a[tag_m_8x]);
+	m0 = _mm512_maskz_loadu_ps(a_mask, &a[tag_m_8x*2]);
 	m1 = _mm512_mul_ps(_mm512_mul_ps(m0, x1Array), ALPHAVECTOR);
 	m2 = _mm512_permutexvar_ps(_mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0), m1);
 	__m256 ret = _mm256_add_ps(_mm512_extractf32x8_ps(m2, 1), _mm512_extractf32x8_ps(m2, 0));
@@ -322,7 +322,7 @@ static int sgemv_kernel_t_4(BLASLONG m, float alpha, float *a, float *x, float *
 {
 	BLASLONG tag_m_4x = m & (~3);
 	BLASLONG tag_m_2x = m & (~1);
-	__m512 m0, m1, m2;
+	__m512 m0, m1;
 	__m256 m256_0, m256_1, c256_1, c256_2;
 	__m128 c1, c2, c3, c4, ret;
 	__m128 xarray = _mm_maskz_loadu_ps(0x0f, x);
@@ -346,7 +346,7 @@ static int sgemv_kernel_t_4(BLASLONG m, float alpha, float *a, float *x, float *
 		c3 = _mm256_extractf32x4_ps(c256_2, 0);
 		c4 = _mm256_extractf32x4_ps(c256_2, 1);
 
-		ret = _mm_maskz_add_ps(0xff, _mm_maskz_add_ps(0xff, _mm_maskz_add_ps(0xff, c1, c2), _mm_maskz_add_ps(0xff, c3, c4)), _mm_maskz_loadu_ps(0xff, y));
+		ret = _mm_maskz_add_ps(0xff, _mm_maskz_add_ps(0xff, _mm_maskz_add_ps(0xff, c1, c2), _mm_maskz_add_ps(0xff, c3, c4)), _mm_maskz_loadu_ps(0xff, &y[idx_m]));
 		_mm_mask_storeu_ps(&y[idx_m], 0xff, ret);
 	}
 
@@ -958,6 +958,7 @@ static int sgemv_kernel_t_7(BLASLONG m, float alpha, float *a, float *x, float *
 		c256_1 = _mm512_extractf32x8_ps(tmp0, 1);
 		c256_0 = _mm256_add_ps(c256_0, c256_1);
+		c256_0 = _mm256_mul_ps(c256_0, alpha256);
 
 		__m128 c128_0 = _mm256_extractf32x4_ps(c256_0, 0);
 		__m128 c128_1 = _mm256_extractf32x4_ps(c256_0, 1);
@@ -1016,9 +1017,10 @@ static int sgemv_kernel_t_8(BLASLONG m, float alpha, float *a, float *x, float *
 	__m512 m0, m1, m2, m3;
 	__m256 r0, r1, r2, r3, r4, r5, r6, r7, tmp0, tmp1, tmp2, tmp3;
 	__m128 c128_0, c128_1, c128_2, c128_3;
-	__m128 alpha128 = _mm_set1_ps(alpha);
+	__m256 alpha256 = _mm256_set1_ps(alpha);
 
 	__m256 x256 = _mm256_loadu_ps(x);
+	x256 = _mm256_mul_ps(x256, alpha256);
 	__m512 x512 = _mm512_broadcast_f32x8(x256);
 
 	for(BLASLONG idx_m=0; idx_m
[...]
-#ifdef __CTEST_MSVC
+#ifdef _WIN32
 #include 
 #else
 #include 
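NOTE (reviewer illustration, not part of the patch): in the tail_mask_value hunk above, C precedence makes '-' bind tighter than '&', so the old expression 16-((m-tag_m_8x)*2)&15 already parsed as (16 - 2*(m-tag_m_8x)) & 15. For the one to seven leftover rows this 8-row-blocked tail can see, both groupings yield the same shift, so the added parentheses appear to document intent rather than change results; the behavioral fix in that hunk looks to be the &a[tag_m_8x] to &a[tag_m_8x*2] address change, consistent with two floats consumed per row. A quick check of the mask equivalence:

    #include <assert.h>

    int main(void) {
        for (int r = 1; r <= 7; r++) {  /* leftover rows after 8-row blocking */
            unsigned short old_mask = (unsigned short)(0xffffu >> (16 - r * 2 & 15));
            unsigned short new_mask = (unsigned short)(0xffffu >> (16 - ((r * 2) & 15)));
            assert(old_mask == new_mask);  /* both set the low 2*r bits */
        }
        return 0;
    }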