ENH: Add TRMM_KERNEL bindings
This commit is contained in:
parent
854aecce82
commit
571d2f3be3
|
@ -450,8 +450,8 @@ endforeach
|
||||||
# Create the static libraries from the configurations
|
# Create the static libraries from the configurations
|
||||||
_kern_libs = []
|
_kern_libs = []
|
||||||
foreach conf : kernel_confs
|
foreach conf : kernel_confs
|
||||||
message(conf['name'])
|
# message(conf['name'])
|
||||||
message(conf)
|
# message(conf)
|
||||||
_kern_libs += [static_library(
|
_kern_libs += [static_library(
|
||||||
conf['name'],
|
conf['name'],
|
||||||
conf['src'],
|
conf['src'],
|
||||||
|
|
|
@ -357,16 +357,101 @@ base_kops = [
|
||||||
},
|
},
|
||||||
# Level 3 symbols
|
# Level 3 symbols
|
||||||
# TODO(rg): Handle the if defines set in arch
|
# TODO(rg): Handle the if defines set in arch
|
||||||
# { 'base': '?gemm_beta', # This is a bfloat16 symbol, skipping for now
|
{ 'base': '?gemm_beta',
|
||||||
# 'modes': {
|
'modes': {
|
||||||
# 's': {'exts': {'': {'dir': 'x86_64', 'kernel': 'gemm_beta.S'}}},
|
's': {'exts': {'': {'dir': 'x86_64', 'kernel': 'gemm_beta.S'}}},
|
||||||
# 'd': {'exts': {'': {'dir': 'x86_64', 'kernel': 'gemm_beta.S'}}},
|
'd': {'exts': {'': {'dir': 'x86_64', 'kernel': 'gemm_beta.S'}}},
|
||||||
# 'c': {'exts': {'': {'dir': 'x86_64', 'kernel': 'zgemm_beta.S'}}},
|
'c': {'exts': {'': {'dir': 'x86_64', 'kernel': 'zgemm_beta.S'}}},
|
||||||
# 'z': {'exts': {'': {'dir': 'x86_64', 'kernel': 'zgemm_beta.S'}}},
|
'z': {'exts': {'': {'dir': 'x86_64', 'kernel': 'zgemm_beta.S'}}},
|
||||||
# # 'q': {'exts': {'': {'dir': 'generic', 'kernel': 'gemm_beta.c'}}},
|
# 'q': {'exts': {'': {'dir': 'generic', 'kernel': 'gemm_beta.c'}}},
|
||||||
# # 'x': {'exts': {'': {'dir': 'generic', 'kernel': 'zgemm_beta.c'}}},
|
# 'x': {'exts': {'': {'dir': 'generic', 'kernel': 'zgemm_beta.c'}}},
|
||||||
# },
|
},
|
||||||
# },
|
},
|
||||||
|
{ 'base': '?gemm_kernel',
|
||||||
|
'modes': {
|
||||||
|
's': {'exts': {'': {'dir': 'generic', 'kernel': 'gemmkernel_2x2.c'}}},
|
||||||
|
'd': {'exts': {'': {'dir': 'generic', 'kernel': 'gemmkernel_2x2.c'}}},
|
||||||
|
'c': {
|
||||||
|
'exts': {
|
||||||
|
'_n': {'dir': 'generic', 'kernel': 'zgemmkernel_2x2.c', 'addl': ['-DNN']},
|
||||||
|
'_l': {'dir': 'generic', 'kernel': 'zgemmkernel_2x2.c', 'addl': ['-DCN']},
|
||||||
|
# TODO(rg): What about _r conditionals? Makefile.L3:2969
|
||||||
|
'_r': {'dir': 'generic', 'kernel': 'zgemmkernel_2x2.c', 'addl': ['-DNC']},
|
||||||
|
'_b': {'dir': 'generic', 'kernel': 'zgemmkernel_2x2.c', 'addl': ['-DCC']},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'z': {
|
||||||
|
'exts': {
|
||||||
|
'_n': {'dir': 'generic', 'kernel': 'zgemmkernel_2x2.c', 'addl': ['-DNN']},
|
||||||
|
'_l': {'dir': 'generic', 'kernel': 'zgemmkernel_2x2.c', 'addl': ['-DCN']},
|
||||||
|
'_r': {'dir': 'generic', 'kernel': 'zgemmkernel_2x2.c', 'addl': ['-DNC']},
|
||||||
|
'_b': {'dir': 'generic', 'kernel': 'zgemmkernel_2x2.c', 'addl': ['-DCC']},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
# 'q': {'exts': {'': {'dir': 'generic', 'kernel': 'gemm_beta.c'}}},
|
||||||
|
# 'x': {'exts': {'': {'dir': 'generic', 'kernel': 'zgemm_beta.c'}}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{ 'base': '?trmm_kernel',
|
||||||
|
'modes': {
|
||||||
|
's': {
|
||||||
|
'exts': {
|
||||||
|
'_LN': {'dir': 'generic', 'kernel': 'trmmkernel_2x2.c'},
|
||||||
|
'_LT': {'dir': 'generic', 'kernel': 'trmmkernel_2x2.c', 'addl': ['-DLEFT', '-DTRANSA']},
|
||||||
|
'_RN': {'dir': 'generic', 'kernel': 'trmmkernel_2x2.c'},
|
||||||
|
'_RT': {'dir': 'generic', 'kernel': 'trmmkernel_2x2.c'},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'd': {
|
||||||
|
'exts': {
|
||||||
|
'_LN': {'dir': 'generic', 'kernel': 'trmmkernel_2x2.c'},
|
||||||
|
'_LT': {'dir': 'generic', 'kernel': 'trmmkernel_2x2.c', 'addl': ['-DLEFT', '-DTRANSA']},
|
||||||
|
'_RN': {'dir': 'generic', 'kernel': 'trmmkernel_2x2.c'},
|
||||||
|
'_RT': {'dir': 'generic', 'kernel': 'trmmkernel_2x2.c'},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'c': {
|
||||||
|
'exts': {
|
||||||
|
'_LN': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
|
||||||
|
'addl': ['-UCONJ', '-DNN']},
|
||||||
|
'_LT': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
|
||||||
|
'addl': ['-DLEFT', '-DTRANSA', '-UCONJ', '-DNN']},
|
||||||
|
'_LR': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
|
||||||
|
'addl': ['-DCONJ', '-DCN']},
|
||||||
|
'_LC': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
|
||||||
|
'addl': ['-DCONJ', '-DCN']},
|
||||||
|
'_RN': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
|
||||||
|
'addl': ['-UCONJ', '-DNN']},
|
||||||
|
'_RT': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
|
||||||
|
'addl': ['-ULEFT', '-DTRANSA', '-UCONJ', '-DNN']},
|
||||||
|
'_RR': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
|
||||||
|
'addl': ['-DCONJ', '-DNC']},
|
||||||
|
'_RC': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
|
||||||
|
'addl': ['-DCONJ', '-DCN']},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'z': {
|
||||||
|
'exts': {
|
||||||
|
'_LN': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
|
||||||
|
'addl': ['-UCONJ', '-DNN']},
|
||||||
|
'_LT': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
|
||||||
|
'addl': ['-DLEFT', '-DTRANSA', '-UCONJ', '-DNN']},
|
||||||
|
'_LR': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
|
||||||
|
'addl': ['-DCONJ', '-DCN']},
|
||||||
|
'_LC': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
|
||||||
|
'addl': ['-DCONJ', '-DCN']},
|
||||||
|
'_RN': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
|
||||||
|
'addl': ['-UCONJ', '-DNN']},
|
||||||
|
'_RT': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
|
||||||
|
'addl': ['-ULEFT', '-DTRANSA', '-UCONJ', '-DNN']},
|
||||||
|
'_RR': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
|
||||||
|
'addl': ['-DCONJ', '-DNC']},
|
||||||
|
'_RC': {'dir': 'generic', 'kernel': 'ztrmmkernel_2x2.c',
|
||||||
|
'addl': ['-DCONJ', '-DCN']},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
kernel_confs = []
|
kernel_confs = []
|
||||||
|
@ -394,9 +479,9 @@ foreach _kop : base_kops
|
||||||
prec_mode = precision_mappings[mode]
|
prec_mode = precision_mappings[mode]
|
||||||
# Generate the mapping for the type
|
# Generate the mapping for the type
|
||||||
if prec_mode.has_key('def')
|
if prec_mode.has_key('def')
|
||||||
foreach _d : prec_mode['def']
|
foreach _d : prec_mode['def']
|
||||||
__cargs += ('-D' + _d)
|
__cargs += ('-D' + _d)
|
||||||
endforeach
|
endforeach
|
||||||
endif
|
endif
|
||||||
if prec_mode.has_key('undef')
|
if prec_mode.has_key('undef')
|
||||||
foreach _u : prec_mode['undef']
|
foreach _u : prec_mode['undef']
|
||||||
|
@ -406,18 +491,38 @@ foreach _kop : base_kops
|
||||||
# Now the rest, one run for each ext, to get the final symbols
|
# Now the rest, one run for each ext, to get the final symbols
|
||||||
foreach ext, extdat : details['exts']
|
foreach ext, extdat : details['exts']
|
||||||
_ext_cargs = [] # Will be wiped for each ext preventing redefinitions
|
_ext_cargs = [] # Will be wiped for each ext preventing redefinitions
|
||||||
extmap = ext_mappings[ext]
|
# Check ext_mappings first
|
||||||
if extmap.has_key('def')
|
if ext_mappings.has_key(ext) and (not ext_mappings.has_key('except') or base not in ext_mappings['except'])
|
||||||
foreach _d : extmap['def']
|
extmap = ext_mappings[ext]
|
||||||
_ext_cargs += ('-D' + _d)
|
if extmap.has_key('def')
|
||||||
endforeach
|
foreach _d : extmap['def']
|
||||||
endif
|
_ext_cargs += ['-D' + _d]
|
||||||
if extmap.has_key('undef')
|
endforeach
|
||||||
foreach _u : extmap['undef']
|
endif
|
||||||
_ext_cargs += ('-U' + _u)
|
if extmap.has_key('undef')
|
||||||
endforeach
|
foreach _u : extmap['undef']
|
||||||
endif
|
_ext_cargs += ['-U' + _u]
|
||||||
# Construct the final paths
|
endforeach
|
||||||
|
endif
|
||||||
|
else
|
||||||
|
# Fallback to ext_mappings_l2
|
||||||
|
foreach ext_map : ext_mappings_l2 + ext_mappings_l3
|
||||||
|
if ext_map['ext'] == ext and mode in ext_map['for'] and (not ext_map.has_key('except') or base not in ext_map['except'])
|
||||||
|
if ext_map.has_key('def')
|
||||||
|
foreach _d : ext_map['def']
|
||||||
|
_ext_cargs += ['-D' + _d]
|
||||||
|
endforeach
|
||||||
|
endif
|
||||||
|
if ext_map.has_key('undef')
|
||||||
|
foreach _u : ext_map['undef']
|
||||||
|
_ext_cargs += ['-U' + _u]
|
||||||
|
endforeach
|
||||||
|
endif
|
||||||
|
break
|
||||||
|
endif
|
||||||
|
endforeach
|
||||||
|
endif
|
||||||
|
|
||||||
src = join_paths(extdat['dir'], extdat['kernel'])
|
src = join_paths(extdat['dir'], extdat['kernel'])
|
||||||
if extdat.has_key('addl')
|
if extdat.has_key('addl')
|
||||||
_ext_cargs += extdat['addl']
|
_ext_cargs += extdat['addl']
|
||||||
|
@ -444,8 +549,8 @@ endforeach
|
||||||
|
|
||||||
_kern_libs = []
|
_kern_libs = []
|
||||||
foreach conf: kernel_confs
|
foreach conf: kernel_confs
|
||||||
# message(conf['name'])
|
message(conf['name'])
|
||||||
# message(conf)
|
message(conf)
|
||||||
_kern_libs += static_library(
|
_kern_libs += static_library(
|
||||||
conf['name'],
|
conf['name'],
|
||||||
conf['src'],
|
conf['src'],
|
||||||
|
|
|
@ -303,6 +303,10 @@ ext_mappings = {
|
||||||
'_LL': {'def': ['LOWER', 'NN'], 'undef': ['RSIDE']},
|
'_LL': {'def': ['LOWER', 'NN'], 'undef': ['RSIDE']},
|
||||||
'_RU': {'def': ['RSIDE', 'NN'], 'undef': ['LOWER'], 'except': ['?hemm', '?hemm_thread']},
|
'_RU': {'def': ['RSIDE', 'NN'], 'undef': ['LOWER'], 'except': ['?hemm', '?hemm_thread']},
|
||||||
'_RL': {'def': ['RSIDE', 'NN', 'LOWER'], 'except': ['?hemm', '?hemm_thread']},
|
'_RL': {'def': ['RSIDE', 'NN', 'LOWER'], 'except': ['?hemm', '?hemm_thread']},
|
||||||
|
'_RN': {'undef': ['LEFT', 'TRANSA']},
|
||||||
|
'_RR': {'undef': ['LEFT', 'TRANSA']},
|
||||||
|
'_RT': {'def': ['TRANSA'], 'undef': ['LEFT']},
|
||||||
|
'_RC': {'def': ['TRANSA'], 'undef': ['LEFT']},
|
||||||
# TODO(rg): is CONJ OK for interface symbols?
|
# TODO(rg): is CONJ OK for interface symbols?
|
||||||
'_UN': {'undef': ['TRANS', 'LOWER', 'CONJ'], 'except': ['?syrk']},
|
'_UN': {'undef': ['TRANS', 'LOWER', 'CONJ'], 'except': ['?syrk']},
|
||||||
'_UT': {'def': ['TRANS'], 'undef': ['LOWER'], 'except': ['?syrk']},
|
'_UT': {'def': ['TRANS'], 'undef': ['LOWER'], 'except': ['?syrk']},
|
||||||
|
@ -393,8 +397,8 @@ ext_mappings_l3 = [
|
||||||
# syrk
|
# syrk
|
||||||
{'ext': '_UN', 'def': [], 'undef': ['LOWER', 'TRANS'], 'for': ['s', 'd', 'c', 'z']},
|
{'ext': '_UN', 'def': [], 'undef': ['LOWER', 'TRANS'], 'for': ['s', 'd', 'c', 'z']},
|
||||||
{'ext': '_UT', 'def': ['TRANS'], 'undef': ['LOWER'], 'for': ['s', 'd', 'c', 'z']},
|
{'ext': '_UT', 'def': ['TRANS'], 'undef': ['LOWER'], 'for': ['s', 'd', 'c', 'z']},
|
||||||
{'ext': '_LN', 'def': ['LOWER'], 'undef': ['TRANS', 'CONJ'], 'for': ['s', 'd', 'c', 'z']},
|
{'ext': '_LN', 'def': ['LOWER'], 'undef': ['TRANS', 'CONJ'], 'for': ['s', 'd', 'c', 'z'], 'except': ['?trmm_kernel']},
|
||||||
{'ext': '_LT', 'def': ['TRANS', 'LOWER'], 'for': ['s', 'd', 'c', 'z']},
|
{'ext': '_LT', 'def': ['TRANS', 'LOWER'], 'for': ['s', 'd', 'c', 'z'], 'except': ['?trmm_kernel']},
|
||||||
{'ext': '_RU', 'def': ['RSIDE', 'NC'], 'undef': ['LOWER'], 'for': ['c', 'z']},
|
{'ext': '_RU', 'def': ['RSIDE', 'NC'], 'undef': ['LOWER'], 'for': ['c', 'z']},
|
||||||
{'ext': '_RL', 'def': ['RSIDE', 'NC', 'LOWER'], 'for': ['c', 'z']},
|
{'ext': '_RL', 'def': ['RSIDE', 'NC', 'LOWER'], 'for': ['c', 'z']},
|
||||||
]
|
]
|
||||||
|
@ -423,6 +427,7 @@ symb_defs = {
|
||||||
'?her_thread': {'def': ['HER']},
|
'?her_thread': {'def': ['HER']},
|
||||||
'?her2_thread': {'def': ['HER']},
|
'?her2_thread': {'def': ['HER']},
|
||||||
'?hpr_thread': {'def': ['HEMV']},
|
'?hpr_thread': {'def': ['HEMV']},
|
||||||
|
'?trmm_kernel': {'def': ['TRMMKERNEL']},
|
||||||
'?bgemm': {'def': ['HALF']},
|
'?bgemm': {'def': ['HALF']},
|
||||||
'cblas_?dotu_sub': {'def': ['CBLAS', 'FORCE_USE_STACK'], 'undef': ['CONJ']},
|
'cblas_?dotu_sub': {'def': ['CBLAS', 'FORCE_USE_STACK'], 'undef': ['CONJ']},
|
||||||
'cblas_?dotc_sub': {'def': ['CBLAS', 'FORCE_USE_STACK', 'CONJ']},
|
'cblas_?dotc_sub': {'def': ['CBLAS', 'FORCE_USE_STACK', 'CONJ']},
|
||||||
|
|
Loading…
Reference in New Issue