Merge pull request #14537 from taosdata/fix/TD-16877

fix(query): stddev function support unsigned data types
This commit is contained in:
Ganlin Zhao 2022-07-06 10:14:23 +08:00 committed by GitHub
commit 9355392c80
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 122 additions and 33 deletions

View File

@ -90,12 +90,14 @@ typedef struct SStddevRes {
double result; double result;
int64_t count; int64_t count;
union { union {
double quadraticDSum; double quadraticDSum;
int64_t quadraticISum; int64_t quadraticISum;
uint64_t quadraticUSum;
}; };
union { union {
double dsum; double dsum;
int64_t isum; int64_t isum;
uint64_t usum;
}; };
int16_t type; int16_t type;
} SStddevRes; } SStddevRes;
@ -1729,6 +1731,68 @@ int32_t stddevFunction(SqlFunctionCtx* pCtx) {
break; break;
} }
case TSDB_DATA_TYPE_UTINYINT: {
uint8_t* plist = (uint8_t*)pCol->pData;
for (int32_t i = start; i < numOfRows + start; ++i) {
if (pCol->hasNull && colDataIsNull_f(pCol->nullbitmap, i)) {
continue;
}
numOfElem += 1;
pStddevRes->count += 1;
pStddevRes->usum += plist[i];
pStddevRes->quadraticISum += plist[i] * plist[i];
}
break;
}
case TSDB_DATA_TYPE_USMALLINT: {
uint16_t* plist = (uint16_t*)pCol->pData;
for (int32_t i = start; i < numOfRows + pInput->startRowIndex; ++i) {
if (pCol->hasNull && colDataIsNull_f(pCol->nullbitmap, i)) {
continue;
}
numOfElem += 1;
pStddevRes->count += 1;
pStddevRes->usum += plist[i];
pStddevRes->quadraticISum += plist[i] * plist[i];
}
break;
}
case TSDB_DATA_TYPE_UINT: {
uint32_t* plist = (uint32_t*)pCol->pData;
for (int32_t i = start; i < numOfRows + pInput->startRowIndex; ++i) {
if (pCol->hasNull && colDataIsNull_f(pCol->nullbitmap, i)) {
continue;
}
numOfElem += 1;
pStddevRes->count += 1;
pStddevRes->usum += plist[i];
pStddevRes->quadraticISum += plist[i] * plist[i];
}
break;
}
case TSDB_DATA_TYPE_UBIGINT: {
uint64_t* plist = (uint64_t*)pCol->pData;
for (int32_t i = start; i < numOfRows + pInput->startRowIndex; ++i) {
if (pCol->hasNull && colDataIsNull_f(pCol->nullbitmap, i)) {
continue;
}
numOfElem += 1;
pStddevRes->count += 1;
pStddevRes->usum += plist[i];
pStddevRes->quadraticISum += plist[i] * plist[i];
}
break;
}
case TSDB_DATA_TYPE_FLOAT: { case TSDB_DATA_TYPE_FLOAT: {
float* plist = (float*)pCol->pData; float* plist = (float*)pCol->pData;
for (int32_t i = start; i < numOfRows + pInput->startRowIndex; ++i) { for (int32_t i = start; i < numOfRows + pInput->startRowIndex; ++i) {
@ -1771,9 +1835,12 @@ _stddev_over:
static void stddevTransferInfo(SStddevRes* pInput, SStddevRes* pOutput) { static void stddevTransferInfo(SStddevRes* pInput, SStddevRes* pOutput) {
pOutput->type = pInput->type; pOutput->type = pInput->type;
if (IS_INTEGER_TYPE(pOutput->type)) { if (IS_SIGNED_NUMERIC_TYPE(pOutput->type)) {
pOutput->quadraticISum += pInput->quadraticISum; pOutput->quadraticISum += pInput->quadraticISum;
pOutput->isum += pInput->isum; pOutput->isum += pInput->isum;
} else if (IS_UNSIGNED_NUMERIC_TYPE(pOutput->type)) {
pOutput->quadraticUSum += pInput->quadraticUSum;
pOutput->usum += pInput->usum;
} else { } else {
pOutput->quadraticDSum += pInput->quadraticDSum; pOutput->quadraticDSum += pInput->quadraticDSum;
pOutput->dsum += pInput->dsum; pOutput->dsum += pInput->dsum;
@ -1848,6 +1915,22 @@ int32_t stddevInvertFunction(SqlFunctionCtx* pCtx) {
LIST_STDDEV_SUB_N(pStddevRes->isum, int64_t); LIST_STDDEV_SUB_N(pStddevRes->isum, int64_t);
break; break;
} }
case TSDB_DATA_TYPE_UTINYINT: {
LIST_STDDEV_SUB_N(pStddevRes->isum, uint8_t);
break;
}
case TSDB_DATA_TYPE_USMALLINT: {
LIST_STDDEV_SUB_N(pStddevRes->isum, uint16_t);
break;
}
case TSDB_DATA_TYPE_UINT: {
LIST_STDDEV_SUB_N(pStddevRes->isum, uint32_t);
break;
}
case TSDB_DATA_TYPE_UBIGINT: {
LIST_STDDEV_SUB_N(pStddevRes->isum, uint64_t);
break;
}
case TSDB_DATA_TYPE_FLOAT: { case TSDB_DATA_TYPE_FLOAT: {
LIST_STDDEV_SUB_N(pStddevRes->dsum, float); LIST_STDDEV_SUB_N(pStddevRes->dsum, float);
break; break;
@ -1871,9 +1954,12 @@ int32_t stddevFinalize(SqlFunctionCtx* pCtx, SSDataBlock* pBlock) {
int32_t type = pStddevRes->type; int32_t type = pStddevRes->type;
double avg; double avg;
if (IS_INTEGER_TYPE(type)) { if (IS_SIGNED_NUMERIC_TYPE(type)) {
avg = pStddevRes->isum / ((double)pStddevRes->count); avg = pStddevRes->isum / ((double)pStddevRes->count);
pStddevRes->result = sqrt(fabs(pStddevRes->quadraticISum / ((double)pStddevRes->count) - avg * avg)); pStddevRes->result = sqrt(fabs(pStddevRes->quadraticISum / ((double)pStddevRes->count) - avg * avg));
} else if (IS_UNSIGNED_NUMERIC_TYPE(type)) {
avg = pStddevRes->usum / ((double)pStddevRes->count);
pStddevRes->result = sqrt(fabs(pStddevRes->quadraticUSum / ((double)pStddevRes->count) - avg * avg));
} else { } else {
avg = pStddevRes->dsum / ((double)pStddevRes->count); avg = pStddevRes->dsum / ((double)pStddevRes->count);
pStddevRes->result = sqrt(fabs(pStddevRes->quadraticDSum / ((double)pStddevRes->count) - avg * avg)); pStddevRes->result = sqrt(fabs(pStddevRes->quadraticDSum / ((double)pStddevRes->count) - avg * avg));
@ -1913,9 +1999,12 @@ int32_t stddevCombine(SqlFunctionCtx* pDestCtx, SqlFunctionCtx* pSourceCtx) {
SResultRowEntryInfo* pSResInfo = GET_RES_INFO(pSourceCtx); SResultRowEntryInfo* pSResInfo = GET_RES_INFO(pSourceCtx);
SStddevRes* pSBuf = GET_ROWCELL_INTERBUF(pSResInfo); SStddevRes* pSBuf = GET_ROWCELL_INTERBUF(pSResInfo);
if (IS_INTEGER_TYPE(type)) { if (IS_SIGNED_NUMERIC_TYPE(type)) {
pDBuf->isum += pSBuf->isum; pDBuf->isum += pSBuf->isum;
pDBuf->quadraticISum += pSBuf->quadraticISum; pDBuf->quadraticISum += pSBuf->quadraticISum;
} else if (IS_UNSIGNED_NUMERIC_TYPE(type)) {
pDBuf->usum += pSBuf->usum;
pDBuf->quadraticUSum += pSBuf->quadraticUSum;
} else { } else {
pDBuf->dsum += pSBuf->dsum; pDBuf->dsum += pSBuf->dsum;
pDBuf->quadraticDSum += pSBuf->quadraticDSum; pDBuf->quadraticDSum += pSBuf->quadraticDSum;

View File

@ -7,7 +7,7 @@ import platform
import math import math
class TDTestCase: class TDTestCase:
updatecfgDict = {'debugFlag': 143 ,"cDebugFlag":143,"uDebugFlag":143 ,"rpcDebugFlag":143 , "tmrDebugFlag":143 , updatecfgDict = {'debugFlag': 143 ,"cDebugFlag":143,"uDebugFlag":143 ,"rpcDebugFlag":143 , "tmrDebugFlag":143 ,
"jniDebugFlag":143 ,"simDebugFlag":143,"dDebugFlag":143, "dDebugFlag":143,"vDebugFlag":143,"mDebugFlag":143,"qDebugFlag":143, "jniDebugFlag":143 ,"simDebugFlag":143,"dDebugFlag":143, "dDebugFlag":143,"vDebugFlag":143,"mDebugFlag":143,"qDebugFlag":143,
"wDebugFlag":143,"sDebugFlag":143,"tsdbDebugFlag":143,"tqDebugFlag":143 ,"fsDebugFlag":143 ,"fnDebugFlag":143, "wDebugFlag":143,"sDebugFlag":143,"tsdbDebugFlag":143,"tqDebugFlag":143 ,"fsDebugFlag":143 ,"fnDebugFlag":143,
"maxTablesPerVnode":2 ,"minTablesPerVnode":2,"tableIncStepPerVnode":2 } "maxTablesPerVnode":2 ,"minTablesPerVnode":2,"tableIncStepPerVnode":2 }
@ -24,7 +24,7 @@ class TDTestCase:
stddev_sql = f"select stddev({col_name}) from {tbname};" stddev_sql = f"select stddev({col_name}) from {tbname};"
same_sql = f"select {col_name} from {tbname} where {col_name} is not null " same_sql = f"select {col_name} from {tbname} where {col_name} is not null "
tdSql.query(same_sql) tdSql.query(same_sql)
pre_data = np.array(tdSql.queryResult)[np.array(tdSql.queryResult) != None] pre_data = np.array(tdSql.queryResult)[np.array(tdSql.queryResult) != None]
if (platform.system().lower() == 'windows' and pre_data.dtype == 'int32'): if (platform.system().lower() == 'windows' and pre_data.dtype == 'int32'):
@ -32,21 +32,21 @@ class TDTestCase:
pre_avg = np.sum(pre_data)/len(pre_data) pre_avg = np.sum(pre_data)/len(pre_data)
# Calculate variance # Calculate variance
stddev_result = 0 stddev_result = 0
for num in tdSql.queryResult: for num in tdSql.queryResult:
stddev_result += (num-pre_avg)*(num-pre_avg)/len(tdSql.queryResult) stddev_result += (num-pre_avg)*(num-pre_avg)/len(tdSql.queryResult)
stddev_result = math.sqrt(stddev_result) stddev_result = math.sqrt(stddev_result)
tdSql.query(stddev_sql) tdSql.query(stddev_sql)
if -0.0001 < tdSql.queryResult[0][0]-stddev_result < 0.0001: if -0.0001 < tdSql.queryResult[0][0]-stddev_result < 0.0001:
tdLog.info(" sql:%s; row:0 col:0 data:%d , expect:%d"%(stddev_sql,tdSql.queryResult[0][0],stddev_result)) tdLog.info(" sql:%s; row:0 col:0 data:%d , expect:%d"%(stddev_sql,tdSql.queryResult[0][0],stddev_result))
else: else:
tdLog.exit(" sql:%s; row:0 col:0 data:%d , expect:%d"%(stddev_sql,tdSql.queryResult[0][0],stddev_result)) tdLog.exit(" sql:%s; row:0 col:0 data:%d , expect:%d"%(stddev_sql,tdSql.queryResult[0][0],stddev_result))
def prepare_datas_of_distribute(self): def prepare_datas_of_distribute(self):
# prepare data for 20 tables distributed across different vgroups # prepare data for 20 tables distributed across different vgroups
tdSql.execute("create database if not exists testdb keep 3650 duration 1000 vgroups 5") tdSql.execute("create database if not exists testdb keep 3650 duration 1000 vgroups 5")
tdSql.execute(" use testdb ") tdSql.execute(" use testdb ")
@ -117,17 +117,17 @@ class TDTestCase:
vgroups = tdSql.queryResult vgroups = tdSql.queryResult
vnode_tables={} vnode_tables={}
for vgroup_id in vgroups: for vgroup_id in vgroups:
vnode_tables[vgroup_id[0]]=[] vnode_tables[vgroup_id[0]]=[]
# check the sub_tables of each vnode, make sure the sub_tables have been distributed # check the sub_tables of each vnode, make sure the sub_tables have been distributed
tdSql.query("show tables like 'ct%'") tdSql.query("show tables like 'ct%'")
table_names = tdSql.queryResult table_names = tdSql.queryResult
tablenames = [] tablenames = []
for table_name in table_names: for table_name in table_names:
vnode_tables[table_name[6]].append(table_name[0]) vnode_tables[table_name[6]].append(table_name[0])
self.vnode_disbutes = vnode_tables self.vnode_disbutes = vnode_tables
count = 0 count = 0
@ -138,14 +138,14 @@ class TDTestCase:
tdLog.exit(" the datas of all not satisfy sub_table has been distributed ") tdLog.exit(" the datas of all not satisfy sub_table has been distributed ")
def check_stddev_distribute_diff_vnode(self,col_name): def check_stddev_distribute_diff_vnode(self,col_name):
vgroup_ids = [] vgroup_ids = []
for k ,v in self.vnode_disbutes.items(): for k ,v in self.vnode_disbutes.items():
if len(v)>=2: if len(v)>=2:
vgroup_ids.append(k) vgroup_ids.append(k)
distribute_tbnames = [] distribute_tbnames = []
for vgroup_id in vgroup_ids: for vgroup_id in vgroup_ids:
vnode_tables = self.vnode_disbutes[vgroup_id] vnode_tables = self.vnode_disbutes[vgroup_id]
distribute_tbnames.append(random.sample(vnode_tables,1)[0]) distribute_tbnames.append(random.sample(vnode_tables,1)[0])
@ -154,7 +154,7 @@ class TDTestCase:
tbname_ins += "'%s' ,"%tbname tbname_ins += "'%s' ,"%tbname
tbname_filters = tbname_ins[:-1] tbname_filters = tbname_ins[:-1]
stddev_sql = f"select stddev({col_name}) from stb1 where tbname in ({tbname_filters});" stddev_sql = f"select stddev({col_name}) from stb1 where tbname in ({tbname_filters});"
same_sql = f"select {col_name} from stb1 where tbname in ({tbname_filters}) and {col_name} is not null " same_sql = f"select {col_name} from stb1 where tbname in ({tbname_filters}) and {col_name} is not null "
@ -166,7 +166,7 @@ class TDTestCase:
pre_avg = np.sum(pre_data)/len(pre_data) pre_avg = np.sum(pre_data)/len(pre_data)
# Calculate variance # Calculate variance
stddev_result = 0 stddev_result = 0
for num in tdSql.queryResult: for num in tdSql.queryResult:
stddev_result += (num-pre_avg)*(num-pre_avg)/len(tdSql.queryResult) stddev_result += (num-pre_avg)*(num-pre_avg)/len(tdSql.queryResult)
@ -177,8 +177,8 @@ class TDTestCase:
def check_stddev_status(self): def check_stddev_status(self):
# check stddev function work status # check stddev function work status
tdSql.query("show tables like 'ct%'") tdSql.query("show tables like 'ct%'")
table_names = tdSql.queryResult table_names = tdSql.queryResult
tablenames = [] tablenames = []
@ -187,31 +187,31 @@ class TDTestCase:
tdSql.query("desc stb1") tdSql.query("desc stb1")
col_names = tdSql.queryResult col_names = tdSql.queryResult
colnames = [] colnames = []
for col_name in col_names: for col_name in col_names:
if col_name[1] in ["INT" ,"BIGINT" ,"SMALLINT" ,"TINYINT" , "FLOAT" ,"DOUBLE"]: if col_name[1] in ["INT" ,"BIGINT" ,"SMALLINT" ,"TINYINT" , "FLOAT" ,"DOUBLE"]:
colnames.append(col_name[0]) colnames.append(col_name[0])
for tablename in tablenames: for tablename in tablenames:
for colname in colnames: for colname in colnames:
if colname.startswith("c"): if colname.startswith("c"):
self.check_stddev_functions(tablename,colname) self.check_stddev_functions(tablename,colname)
else: else:
# self.check_stddev_functions(tablename,colname) # self.check_stddev_functions(tablename,colname)
pass pass
# check stddev function across different vnodes # check stddev function across different vnodes
for colname in colnames: for colname in colnames:
if colname.startswith("c"): if colname.startswith("c"):
self.check_stddev_distribute_diff_vnode(colname) self.check_stddev_distribute_diff_vnode(colname)
else: else:
# self.check_stddev_distribute_diff_vnode(colname) # bug for tag # self.check_stddev_distribute_diff_vnode(colname) # bug for tag
pass pass
def distribute_agg_query(self): def distribute_agg_query(self):
# basic filter # basic filter
tdSql.query(" select stddev(c1) from stb1 ") tdSql.query(" select stddev(c1) from stb1 ")
@ -235,7 +235,7 @@ class TDTestCase:
tdSql.query("select stddev(c1) from stb1 where t1> 4 partition by tbname") tdSql.query("select stddev(c1) from stb1 where t1> 4 partition by tbname")
tdSql.checkRows(15) tdSql.checkRows(15)
# union all # union all
tdSql.query("select stddev(c1) from stb1 union all select stddev(c1) from stb1 ") tdSql.query("select stddev(c1) from stb1 union all select stddev(c1) from stb1 ")
tdSql.checkRows(2) tdSql.checkRows(2)
tdSql.checkData(0,0,6.694663959) tdSql.checkData(0,0,6.694663959)
@ -244,7 +244,7 @@ class TDTestCase:
tdSql.checkRows(1) tdSql.checkRows(1)
tdSql.checkData(0,0,0.000000000) tdSql.checkData(0,0,0.000000000)
# join # join
tdSql.execute(" create database if not exists db ") tdSql.execute(" create database if not exists db ")
tdSql.execute(" use db ") tdSql.execute(" use db ")
@ -252,7 +252,7 @@ class TDTestCase:
tdSql.execute(" create table tb1 using st tags(1) ") tdSql.execute(" create table tb1 using st tags(1) ")
tdSql.execute(" create table tb2 using st tags(2) ") tdSql.execute(" create table tb2 using st tags(2) ")
for i in range(10): for i in range(10):
ts = i*10 + self.ts ts = i*10 + self.ts
tdSql.execute(f" insert into tb1 values({ts},{i},{i}.0)") tdSql.execute(f" insert into tb1 values({ts},{i},{i}.0)")
@ -263,7 +263,7 @@ class TDTestCase:
tdSql.checkData(0,0,2.872281323) tdSql.checkData(0,0,2.872281323)
tdSql.checkData(0,1,2.872281323) tdSql.checkData(0,1,2.872281323)
# group by # group by
tdSql.execute(" use testdb ") tdSql.execute(" use testdb ")
# partition by tbname or partition by tag # partition by tbname or partition by tag
@ -295,7 +295,7 @@ class TDTestCase:
self.check_stddev_status() self.check_stddev_status()
self.distribute_agg_query() self.distribute_agg_query()
def stop(self): def stop(self):
tdSql.close() tdSql.close()
tdLog.success("%s successfully executed" % __file__) tdLog.success("%s successfully executed" % __file__)