ENH: Add TRMM_KERNEL bindings

This commit is contained in:
Rohit Goswami 2024-05-26 00:00:12 +00:00 committed by Mateusz Sokół
parent 854aecce82
commit 571d2f3be3
3 changed files with 141 additions and 31 deletions

View File

@ -450,8 +450,8 @@ endforeach
# Create the static libraries from the configurations
_kern_libs = []
foreach conf : kernel_confs
message(conf['name'])
message(conf)
# message(conf['name'])
# message(conf)
_kern_libs += [static_library(
conf['name'],
conf['src'],

View File

@ -357,16 +357,101 @@ base_kops = [
},
# Level 3 symbols
# TODO(rg): Handle the if defines set in arch
# { 'base': '?gemm_beta', # This is a bfloat16 symbol, skipping for now
# 'modes': {
# 's': {'exts': {'': {'dir': 'x86_64', 'kernel': 'gemm_beta.S'}}},
# 'd': {'exts': {'': {'dir': 'x86_64', 'kernel': 'gemm_beta.S'}}},
# 'c': {'exts': {'': {'dir': 'x86_64', 'kernel': 'zgemm_beta.S'}}},
# 'z': {'exts': {'': {'dir': 'x86_64', 'kernel': 'zgemm_beta.S'}}},
# # 'q': {'exts': {'': {'dir': 'generic', 'kernel': 'gemm_beta.c'}}},
# # 'x': {'exts': {'': {'dir': 'generic', 'kernel': 'zgemm_beta.c'}}},
# },
# },
# Kernel table entry for the ?gemm_beta symbol.
# Each mode maps the empty suffix '' to one source file: 's' and 'd' share
# gemm_beta.S, 'c' and 'z' share zgemm_beta.S.
# NOTE(review): 'dir' is hard-coded to x86_64 for every enabled mode; confirm
# how other architectures are meant to be selected.
{ 'base': '?gemm_beta',
'modes': {
's': {'exts': {'': {'dir': 'x86_64', 'kernel': 'gemm_beta.S'}}},
'd': {'exts': {'': {'dir': 'x86_64', 'kernel': 'gemm_beta.S'}}},
'c': {'exts': {'': {'dir': 'x86_64', 'kernel': 'zgemm_beta.S'}}},
'z': {'exts': {'': {'dir': 'x86_64', 'kernel': 'zgemm_beta.S'}}},
# The 'q' and 'x' modes are disabled; they would use the generic C sources.
# 'q': {'exts': {'': {'dir': 'generic', 'kernel': 'gemm_beta.c'}}},
# 'x': {'exts': {'': {'dir': 'generic', 'kernel': 'zgemm_beta.c'}}},
},
},
# Kernel table entry for the ?gemm_kernel symbol.
# 's' and 'd' build a single object (empty suffix) from gemmkernel_2x2.c with
# no extra flags. 'c' and 'z' build four objects from zgemmkernel_2x2.c, one
# per suffix, distinguished only by the define passed through 'addl':
#   _n -> -DNN, _l -> -DCN, _r -> -DNC, _b -> -DCC
{ 'base': '?gemm_kernel',
'modes': {
's': {'exts': {'': {'dir': 'generic', 'kernel': 'gemmkernel_2x2.c'}}},
'd': {'exts': {'': {'dir': 'generic', 'kernel': 'gemmkernel_2x2.c'}}},
'c': {
'exts': {
'_n': {'dir': 'generic', 'kernel': 'zgemmkernel_2x2.c', 'addl': ['-DNN']},
'_l': {'dir': 'generic', 'kernel': 'zgemmkernel_2x2.c', 'addl': ['-DCN']},
# TODO(rg): What about _r conditionals? Makefile.L3:2969
'_r': {'dir': 'generic', 'kernel': 'zgemmkernel_2x2.c', 'addl': ['-DNC']},
'_b': {'dir': 'generic', 'kernel': 'zgemmkernel_2x2.c', 'addl': ['-DCC']},
}
},
# The 'z' variants mirror the 'c' variants exactly (same source, same defines).
'z': {
'exts': {
'_n': {'dir': 'generic', 'kernel': 'zgemmkernel_2x2.c', 'addl': ['-DNN']},
'_l': {'dir': 'generic', 'kernel': 'zgemmkernel_2x2.c', 'addl': ['-DCN']},
'_r': {'dir': 'generic', 'kernel': 'zgemmkernel_2x2.c', 'addl': ['-DNC']},
'_b': {'dir': 'generic', 'kernel': 'zgemmkernel_2x2.c', 'addl': ['-DCC']},
}
}
# NOTE(review): the disabled 'q'/'x' placeholders below still name
# gemm_beta.c / zgemm_beta.c, which looks copied from the ?gemm_beta entry;
# confirm the intended kernel sources before enabling them.
# 'q': {'exts': {'': {'dir': 'generic', 'kernel': 'gemm_beta.c'}}},
# 'x': {'exts': {'': {'dir': 'generic', 'kernel': 'zgemm_beta.c'}}},
},
},
# Kernel table entry for the ?trmm_kernel symbol.
# 's' and 'd' build four suffixed objects from trmmkernel_2x2.c; 'c' and 'z'
# build eight suffixed objects from ztrmmkernel_2x2.c. Flags listed in 'addl'
# are appended to whatever the suffix mapping tables contribute for the same
# suffix, so an entry without 'addl' relies entirely on those tables.
{ 'base': '?trmm_kernel',
'modes': {
's': {
'exts': {
# NOTE(review): '_LN' carries no '-DLEFT' here, while '_LT' does. Confirm
# that LEFT is defined for '_LN' by a suffix mapping; otherwise it is built
# without LEFT, like the '_R*' variants.
'_LN': {'dir': 'generic', 'kernel': 'trmmkernel_2x2.c'},
'_LT': {'dir': 'generic', 'kernel': 'trmmkernel_2x2.c', 'addl': ['-DLEFT', '-DTRANSA']},
'_RN': {'dir': 'generic', 'kernel': 'trmmkernel_2x2.c'},
'_RT': {'dir': 'generic', 'kernel': 'trmmkernel_2x2.c'},
}
},
# The 'd' variants mirror the 's' variants exactly.
'd': {
'exts': {
'_LN': {'dir': 'generic', 'kernel': 'trmmkernel_2x2.c'},
'_LT': {'dir': 'generic', 'kernel': 'trmmkernel_2x2.c', 'addl': ['-DLEFT', '-DTRANSA']},
'_RN': {'dir': 'generic', 'kernel': 'trmmkernel_2x2.c'},
'_RT': {'dir': 'generic', 'kernel': 'trmmkernel_2x2.c'},
}
},
# Complex variants: the second suffix letter picks the conjugation flags
# (N/T -> -UCONJ -DNN, R/C -> -DCONJ plus -DCN or -DNC).
'c': {
'exts': {
# NOTE(review): as in the real modes, '_LN' lists no '-DLEFT'.
'_LN': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
'addl': ['-UCONJ', '-DNN']},
'_LT': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
'addl': ['-DLEFT', '-DTRANSA', '-UCONJ', '-DNN']},
# NOTE(review): '_LR' and '_LC' list identical 'addl' flags and neither
# names LEFT or TRANSA; unless a suffix mapping distinguishes them, both
# compile to the same code. Confirm against the upstream make rules.
'_LR': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
'addl': ['-DCONJ', '-DCN']},
'_LC': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
'addl': ['-DCONJ', '-DCN']},
'_RN': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
'addl': ['-UCONJ', '-DNN']},
'_RT': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
'addl': ['-ULEFT', '-DTRANSA', '-UCONJ', '-DNN']},
# NOTE(review): '_RR' uses -DNC but '_RC' uses -DCN; confirm which define
# the right-side conjugated variants are supposed to share.
'_RR': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
'addl': ['-DCONJ', '-DNC']},
'_RC': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
'addl': ['-DCONJ', '-DCN']},
}
},
# The 'z' variants mirror the 'c' variants exactly, so the review notes on
# '_LN', '_LR'/'_LC' and '_RR'/'_RC' apply here as well.
'z': {
'exts': {
'_LN': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
'addl': ['-UCONJ', '-DNN']},
'_LT': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
'addl': ['-DLEFT', '-DTRANSA', '-UCONJ', '-DNN']},
'_LR': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
'addl': ['-DCONJ', '-DCN']},
'_LC': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
'addl': ['-DCONJ', '-DCN']},
'_RN': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
'addl': ['-UCONJ', '-DNN']},
'_RT': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
'addl': ['-ULEFT', '-DTRANSA', '-UCONJ', '-DNN']},
'_RR': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
'addl': ['-DCONJ', '-DNC']},
'_RC': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
'addl': ['-DCONJ', '-DCN']},
}
},
},
},
]
kernel_confs = []
@ -394,9 +479,9 @@ foreach _kop : base_kops
prec_mode = precision_mappings[mode]
# Generate the mapping for the type
if prec_mode.has_key('def')
foreach _d : prec_mode['def']
__cargs += ('-D' + _d)
endforeach
foreach _d : prec_mode['def']
__cargs += ('-D' + _d)
endforeach
endif
if prec_mode.has_key('undef')
foreach _u : prec_mode['undef']
@ -406,18 +491,38 @@ foreach _kop : base_kops
# Now the rest, one run for each ext, to get the final symbols
foreach ext, extdat : details['exts']
_ext_cargs = [] # Will be wiped for each ext preventing redefinitions
extmap = ext_mappings[ext]
if extmap.has_key('def')
foreach _d : extmap['def']
_ext_cargs += ('-D' + _d)
endforeach
endif
if extmap.has_key('undef')
foreach _u : extmap['undef']
_ext_cargs += ('-U' + _u)
endforeach
endif
# Construct the final paths
# Resolve the -D/-U flags implied by the symbol suffix `ext` into _ext_cargs.
#
# The keyed table `ext_mappings` is consulted first. Its 'except' lists live
# on the individual suffix entries (e.g. ext_mappings['_RU']['except']), not on
# the table itself, so the exclusion test must look at ext_mappings[ext].
# Checking the table for an 'except' key never matches a suffix-keyed table,
# which would make every exclusion a no-op and the fallback unreachable for
# excepted bases. The right-hand side is only evaluated when the suffix is
# present, because `and` short-circuits.
if ext_mappings.has_key(ext) and (not ext_mappings[ext].has_key('except') or base not in ext_mappings[ext]['except'])
  extmap = ext_mappings[ext]
  if extmap.has_key('def')
    foreach _d : extmap['def']
      _ext_cargs += ['-D' + _d]
    endforeach
  endif
  if extmap.has_key('undef')
    foreach _u : extmap['undef']
      _ext_cargs += ['-U' + _u]
    endforeach
  endif
else
  # Fallback: scan the list-style tables ext_mappings_l2 and ext_mappings_l3.
  # A row applies when its suffix matches, the current mode is listed in
  # 'for', and the row does not exclude this base. Only the first matching
  # row is used.
  foreach ext_map : ext_mappings_l2 + ext_mappings_l3
    if ext_map['ext'] == ext and mode in ext_map['for'] and (not ext_map.has_key('except') or base not in ext_map['except'])
      if ext_map.has_key('def')
        foreach _d : ext_map['def']
          _ext_cargs += ['-D' + _d]
        endforeach
      endif
      if ext_map.has_key('undef')
        foreach _u : ext_map['undef']
          _ext_cargs += ['-U' + _u]
        endforeach
      endif
      break
    endif
  endforeach
endif
src = join_paths(extdat['dir'], extdat['kernel'])
if extdat.has_key('addl')
_ext_cargs += extdat['addl']
@ -444,8 +549,8 @@ endforeach
_kern_libs = []
foreach conf: kernel_confs
# message(conf['name'])
# message(conf)
message(conf['name'])
message(conf)
_kern_libs += static_library(
conf['name'],
conf['src'],

View File

@ -303,6 +303,10 @@ ext_mappings = {
'_LL': {'def': ['LOWER', 'NN'], 'undef': ['RSIDE']},
'_RU': {'def': ['RSIDE', 'NN'], 'undef': ['LOWER'], 'except': ['?hemm', '?hemm_thread']},
'_RL': {'def': ['RSIDE', 'NN', 'LOWER'], 'except': ['?hemm', '?hemm_thread']},
'_RN': {'undef': ['LEFT', 'TRANSA']},
'_RR': {'undef': ['LEFT', 'TRANSA']},
'_RT': {'def': ['TRANSA'], 'undef': ['LEFT']},
'_RC': {'def': ['TRANSA'], 'undef': ['LEFT']},
# TODO(rg): is CONJ OK for interface symbols?
'_UN': {'undef': ['TRANS', 'LOWER', 'CONJ'], 'except': ['?syrk']},
'_UT': {'def': ['TRANS'], 'undef': ['LOWER'], 'except': ['?syrk']},
@ -393,8 +397,8 @@ ext_mappings_l3 = [
# syrk
{'ext': '_UN', 'def': [], 'undef': ['LOWER', 'TRANS'], 'for': ['s', 'd', 'c', 'z']},
{'ext': '_UT', 'def': ['TRANS'], 'undef': ['LOWER'], 'for': ['s', 'd', 'c', 'z']},
{'ext': '_LN', 'def': ['LOWER'], 'undef': ['TRANS', 'CONJ'], 'for': ['s', 'd', 'c', 'z']},
{'ext': '_LT', 'def': ['TRANS', 'LOWER'], 'for': ['s', 'd', 'c', 'z']},
{'ext': '_LN', 'def': ['LOWER'], 'undef': ['TRANS', 'CONJ'], 'for': ['s', 'd', 'c', 'z'], 'except': ['?trmm_kernel']},
{'ext': '_LT', 'def': ['TRANS', 'LOWER'], 'for': ['s', 'd', 'c', 'z'], 'except': ['?trmm_kernel']},
{'ext': '_RU', 'def': ['RSIDE', 'NC'], 'undef': ['LOWER'], 'for': ['c', 'z']},
{'ext': '_RL', 'def': ['RSIDE', 'NC', 'LOWER'], 'for': ['c', 'z']},
]
@ -423,6 +427,7 @@ symb_defs = {
'?her_thread': {'def': ['HER']},
'?her2_thread': {'def': ['HER']},
'?hpr_thread': {'def': ['HEMV']},
'?trmm_kernel': {'def': ['TRMMKERNEL']},
'?bgemm': {'def': ['HALF']},
'cblas_?dotu_sub': {'def': ['CBLAS', 'FORCE_USE_STACK'], 'undef': ['CONJ']},
'cblas_?dotc_sub': {'def': ['CBLAS', 'FORCE_USE_STACK', 'CONJ']},