ENH: Add TRMM_KERNEL bindings
This commit is contained in:
parent
854aecce82
commit
571d2f3be3
|
@ -450,8 +450,8 @@ endforeach
|
|||
# Create the static libraries from the configurations
|
||||
_kern_libs = []
|
||||
foreach conf : kernel_confs
|
||||
message(conf['name'])
|
||||
message(conf)
|
||||
# message(conf['name'])
|
||||
# message(conf)
|
||||
_kern_libs += [static_library(
|
||||
conf['name'],
|
||||
conf['src'],
|
||||
|
|
|
@ -357,16 +357,101 @@ base_kops = [
|
|||
},
|
||||
# Level 3 symbols
|
||||
# TODO(rg): Handle the if defines set in arch
|
||||
# { 'base': '?gemm_beta', # This is a bfloat16 symbol, skipping for now
|
||||
# 'modes': {
|
||||
# 's': {'exts': {'': {'dir': 'x86_64', 'kernel': 'gemm_beta.S'}}},
|
||||
# 'd': {'exts': {'': {'dir': 'x86_64', 'kernel': 'gemm_beta.S'}}},
|
||||
# 'c': {'exts': {'': {'dir': 'x86_64', 'kernel': 'zgemm_beta.S'}}},
|
||||
# 'z': {'exts': {'': {'dir': 'x86_64', 'kernel': 'zgemm_beta.S'}}},
|
||||
# # 'q': {'exts': {'': {'dir': 'generic', 'kernel': 'gemm_beta.c'}}},
|
||||
# # 'x': {'exts': {'': {'dir': 'generic', 'kernel': 'zgemm_beta.c'}}},
|
||||
# },
|
||||
# },
|
||||
{ 'base': '?gemm_beta',
|
||||
'modes': {
|
||||
's': {'exts': {'': {'dir': 'x86_64', 'kernel': 'gemm_beta.S'}}},
|
||||
'd': {'exts': {'': {'dir': 'x86_64', 'kernel': 'gemm_beta.S'}}},
|
||||
'c': {'exts': {'': {'dir': 'x86_64', 'kernel': 'zgemm_beta.S'}}},
|
||||
'z': {'exts': {'': {'dir': 'x86_64', 'kernel': 'zgemm_beta.S'}}},
|
||||
# 'q': {'exts': {'': {'dir': 'generic', 'kernel': 'gemm_beta.c'}}},
|
||||
# 'x': {'exts': {'': {'dir': 'generic', 'kernel': 'zgemm_beta.c'}}},
|
||||
},
|
||||
},
|
||||
{ 'base': '?gemm_kernel',
|
||||
'modes': {
|
||||
's': {'exts': {'': {'dir': 'generic', 'kernel': 'gemmkernel_2x2.c'}}},
|
||||
'd': {'exts': {'': {'dir': 'generic', 'kernel': 'gemmkernel_2x2.c'}}},
|
||||
'c': {
|
||||
'exts': {
|
||||
'_n': {'dir': 'generic', 'kernel': 'zgemmkernel_2x2.c', 'addl': ['-DNN']},
|
||||
'_l': {'dir': 'generic', 'kernel': 'zgemmkernel_2x2.c', 'addl': ['-DCN']},
|
||||
# TODO(rg): What about _r conditionals? Makefile.L3:2969
|
||||
'_r': {'dir': 'generic', 'kernel': 'zgemmkernel_2x2.c', 'addl': ['-DNC']},
|
||||
'_b': {'dir': 'generic', 'kernel': 'zgemmkernel_2x2.c', 'addl': ['-DCC']},
|
||||
}
|
||||
},
|
||||
'z': {
|
||||
'exts': {
|
||||
'_n': {'dir': 'generic', 'kernel': 'zgemmkernel_2x2.c', 'addl': ['-DNN']},
|
||||
'_l': {'dir': 'generic', 'kernel': 'zgemmkernel_2x2.c', 'addl': ['-DCN']},
|
||||
'_r': {'dir': 'generic', 'kernel': 'zgemmkernel_2x2.c', 'addl': ['-DNC']},
|
||||
'_b': {'dir': 'generic', 'kernel': 'zgemmkernel_2x2.c', 'addl': ['-DCC']},
|
||||
}
|
||||
}
|
||||
# 'q': {'exts': {'': {'dir': 'generic', 'kernel': 'gemm_beta.c'}}},
|
||||
# 'x': {'exts': {'': {'dir': 'generic', 'kernel': 'zgemm_beta.c'}}},
|
||||
},
|
||||
},
|
||||
{ 'base': '?trmm_kernel',
|
||||
'modes': {
|
||||
's': {
|
||||
'exts': {
|
||||
'_LN': {'dir': 'generic', 'kernel': 'trmmkernel_2x2.c'},
|
||||
'_LT': {'dir': 'generic', 'kernel': 'trmmkernel_2x2.c', 'addl': ['-DLEFT', '-DTRANSA']},
|
||||
'_RN': {'dir': 'generic', 'kernel': 'trmmkernel_2x2.c'},
|
||||
'_RT': {'dir': 'generic', 'kernel': 'trmmkernel_2x2.c'},
|
||||
}
|
||||
},
|
||||
'd': {
|
||||
'exts': {
|
||||
'_LN': {'dir': 'generic', 'kernel': 'trmmkernel_2x2.c'},
|
||||
'_LT': {'dir': 'generic', 'kernel': 'trmmkernel_2x2.c', 'addl': ['-DLEFT', '-DTRANSA']},
|
||||
'_RN': {'dir': 'generic', 'kernel': 'trmmkernel_2x2.c'},
|
||||
'_RT': {'dir': 'generic', 'kernel': 'trmmkernel_2x2.c'},
|
||||
}
|
||||
},
|
||||
'c': {
|
||||
'exts': {
|
||||
'_LN': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
|
||||
'addl': ['-UCONJ', '-DNN']},
|
||||
'_LT': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
|
||||
'addl': ['-DLEFT', '-DTRANSA', '-UCONJ', '-DNN']},
|
||||
'_LR': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
|
||||
'addl': ['-DCONJ', '-DCN']},
|
||||
'_LC': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
|
||||
'addl': ['-DCONJ', '-DCN']},
|
||||
'_RN': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
|
||||
'addl': ['-UCONJ', '-DNN']},
|
||||
'_RT': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
|
||||
'addl': ['-ULEFT', '-DTRANSA', '-UCONJ', '-DNN']},
|
||||
'_RR': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
|
||||
'addl': ['-DCONJ', '-DNC']},
|
||||
'_RC': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
|
||||
'addl': ['-DCONJ', '-DCN']},
|
||||
}
|
||||
},
|
||||
'z': {
|
||||
'exts': {
|
||||
'_LN': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
|
||||
'addl': ['-UCONJ', '-DNN']},
|
||||
'_LT': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
|
||||
'addl': ['-DLEFT', '-DTRANSA', '-UCONJ', '-DNN']},
|
||||
'_LR': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
|
||||
'addl': ['-DCONJ', '-DCN']},
|
||||
'_LC': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
|
||||
'addl': ['-DCONJ', '-DCN']},
|
||||
'_RN': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
|
||||
'addl': ['-UCONJ', '-DNN']},
|
||||
'_RT': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
|
||||
'addl': ['-ULEFT', '-DTRANSA', '-UCONJ', '-DNN']},
|
||||
'_RR': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
|
||||
'addl': ['-DCONJ', '-DNC']},
|
||||
'_RC': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
|
||||
'addl': ['-DCONJ', '-DCN']},
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
kernel_confs = []
|
||||
|
@ -394,9 +479,9 @@ foreach _kop : base_kops
|
|||
prec_mode = precision_mappings[mode]
|
||||
# Generate the mapping for the type
|
||||
if prec_mode.has_key('def')
|
||||
foreach _d : prec_mode['def']
|
||||
__cargs += ('-D' + _d)
|
||||
endforeach
|
||||
foreach _d : prec_mode['def']
|
||||
__cargs += ('-D' + _d)
|
||||
endforeach
|
||||
endif
|
||||
if prec_mode.has_key('undef')
|
||||
foreach _u : prec_mode['undef']
|
||||
|
@ -406,18 +491,38 @@ foreach _kop : base_kops
|
|||
# Now the rest, one run for each ext, to get the final symbols
|
||||
foreach ext, extdat : details['exts']
|
||||
_ext_cargs = [] # Will be wiped for each ext preventing redefinitions
|
||||
extmap = ext_mappings[ext]
|
||||
if extmap.has_key('def')
|
||||
foreach _d : extmap['def']
|
||||
_ext_cargs += ('-D' + _d)
|
||||
endforeach
|
||||
endif
|
||||
if extmap.has_key('undef')
|
||||
foreach _u : extmap['undef']
|
||||
_ext_cargs += ('-U' + _u)
|
||||
endforeach
|
||||
endif
|
||||
# Construct the final paths
|
||||
# Check ext_mappings first
|
||||
if ext_mappings.has_key(ext) and (not ext_mappings.has_key('except') or base not in ext_mappings['except'])
|
||||
extmap = ext_mappings[ext]
|
||||
if extmap.has_key('def')
|
||||
foreach _d : extmap['def']
|
||||
_ext_cargs += ['-D' + _d]
|
||||
endforeach
|
||||
endif
|
||||
if extmap.has_key('undef')
|
||||
foreach _u : extmap['undef']
|
||||
_ext_cargs += ['-U' + _u]
|
||||
endforeach
|
||||
endif
|
||||
else
|
||||
# Fallback to ext_mappings_l2
|
||||
foreach ext_map : ext_mappings_l2 + ext_mappings_l3
|
||||
if ext_map['ext'] == ext and mode in ext_map['for'] and (not ext_map.has_key('except') or base not in ext_map['except'])
|
||||
if ext_map.has_key('def')
|
||||
foreach _d : ext_map['def']
|
||||
_ext_cargs += ['-D' + _d]
|
||||
endforeach
|
||||
endif
|
||||
if ext_map.has_key('undef')
|
||||
foreach _u : ext_map['undef']
|
||||
_ext_cargs += ['-U' + _u]
|
||||
endforeach
|
||||
endif
|
||||
break
|
||||
endif
|
||||
endforeach
|
||||
endif
|
||||
|
||||
src = join_paths(extdat['dir'], extdat['kernel'])
|
||||
if extdat.has_key('addl')
|
||||
_ext_cargs += extdat['addl']
|
||||
|
@ -444,8 +549,8 @@ endforeach
|
|||
|
||||
_kern_libs = []
|
||||
foreach conf: kernel_confs
|
||||
# message(conf['name'])
|
||||
# message(conf)
|
||||
message(conf['name'])
|
||||
message(conf)
|
||||
_kern_libs += static_library(
|
||||
conf['name'],
|
||||
conf['src'],
|
||||
|
|
|
@ -303,6 +303,10 @@ ext_mappings = {
|
|||
'_LL': {'def': ['LOWER', 'NN'], 'undef': ['RSIDE']},
|
||||
'_RU': {'def': ['RSIDE', 'NN'], 'undef': ['LOWER'], 'except': ['?hemm', '?hemm_thread']},
|
||||
'_RL': {'def': ['RSIDE', 'NN', 'LOWER'], 'except': ['?hemm', '?hemm_thread']},
|
||||
'_RN': {'undef': ['LEFT', 'TRANSA']},
|
||||
'_RR': {'undef': ['LEFT', 'TRANSA']},
|
||||
'_RT': {'def': ['TRANSA'], 'undef': ['LEFT']},
|
||||
'_RC': {'def': ['TRANSA'], 'undef': ['LEFT']},
|
||||
# TODO(rg): is CONJ OK for interface symbols?
|
||||
'_UN': {'undef': ['TRANS', 'LOWER', 'CONJ'], 'except': ['?syrk']},
|
||||
'_UT': {'def': ['TRANS'], 'undef': ['LOWER'], 'except': ['?syrk']},
|
||||
|
@ -393,8 +397,8 @@ ext_mappings_l3 = [
|
|||
# syrk
|
||||
{'ext': '_UN', 'def': [], 'undef': ['LOWER', 'TRANS'], 'for': ['s', 'd', 'c', 'z']},
|
||||
{'ext': '_UT', 'def': ['TRANS'], 'undef': ['LOWER'], 'for': ['s', 'd', 'c', 'z']},
|
||||
{'ext': '_LN', 'def': ['LOWER'], 'undef': ['TRANS', 'CONJ'], 'for': ['s', 'd', 'c', 'z']},
|
||||
{'ext': '_LT', 'def': ['TRANS', 'LOWER'], 'for': ['s', 'd', 'c', 'z']},
|
||||
{'ext': '_LN', 'def': ['LOWER'], 'undef': ['TRANS', 'CONJ'], 'for': ['s', 'd', 'c', 'z'], 'except': ['?trmm_kernel']},
|
||||
{'ext': '_LT', 'def': ['TRANS', 'LOWER'], 'for': ['s', 'd', 'c', 'z'], 'except': ['?trmm_kernel']},
|
||||
{'ext': '_RU', 'def': ['RSIDE', 'NC'], 'undef': ['LOWER'], 'for': ['c', 'z']},
|
||||
{'ext': '_RL', 'def': ['RSIDE', 'NC', 'LOWER'], 'for': ['c', 'z']},
|
||||
]
|
||||
|
@ -423,6 +427,7 @@ symb_defs = {
|
|||
'?her_thread': {'def': ['HER']},
|
||||
'?her2_thread': {'def': ['HER']},
|
||||
'?hpr_thread': {'def': ['HEMV']},
|
||||
'?trmm_kernel': {'def': ['TRMMKERNEL']},
|
||||
'?bgemm': {'def': ['HALF']},
|
||||
'cblas_?dotu_sub': {'def': ['CBLAS', 'FORCE_USE_STACK'], 'undef': ['CONJ']},
|
||||
'cblas_?dotc_sub': {'def': ['CBLAS', 'FORCE_USE_STACK', 'CONJ']},
|
||||
|
|
Loading…
Reference in New Issue