diff --git a/source/libs/parser/src/parTranslater.c b/source/libs/parser/src/parTranslater.c index d746130d22..339590e2a3 100755 --- a/source/libs/parser/src/parTranslater.c +++ b/source/libs/parser/src/parTranslater.c @@ -3186,102 +3186,42 @@ static int32_t rewriteIsTrue(SNode* pSrc, SNode** pIsTrue) { return TSDB_CODE_SUCCESS; } +extern int8_t gConvertTypes[TSDB_DATA_TYPE_MAX][TSDB_DATA_TYPE_MAX]; static bool selectCommonType(SDataType* commonType, const SDataType* newType) { - if (commonType->type == TSDB_DATA_TYPE_NULL) { // 0 null type - *commonType = *newType; - return true; + if (commonType->type < TSDB_DATA_TYPE_NULL || commonType->type >= TSDB_DATA_TYPE_MAX || + newType->type < TSDB_DATA_TYPE_NULL || newType->type >= TSDB_DATA_TYPE_MAX) { + return false; + } + if (commonType->type == TSDB_DATA_TYPE_NULL) { + *commonType = *newType; + return true; } if (newType->type == TSDB_DATA_TYPE_NULL) { - return true; + return true; } - - // type equal if (commonType->type == newType->type) { - if (commonType->bytes <= newType->bytes) { - *commonType = *newType; - } - return true; - } - // 15 ~ 21 these types are not compatible with other types? - if (commonType->type == TSDB_DATA_TYPE_TIMESTAMP || (commonType->type >= TSDB_DATA_TYPE_JSON && commonType->type <= TSDB_DATA_TYPE_MAX)) { - return true; - } - if (newType->type == TSDB_DATA_TYPE_TIMESTAMP || (newType->type >= TSDB_DATA_TYPE_JSON && newType->type <= TSDB_DATA_TYPE_MAX)) { - *commonType = *newType; - return true; - } - - // type 1 ~ 5 - if ((commonType->type >= TSDB_DATA_TYPE_BOOL && commonType->type <= TSDB_DATA_TYPE_BIGINT) && - (newType->type >= TSDB_DATA_TYPE_BOOL && newType->type <= TSDB_DATA_TYPE_BIGINT)) { - if (newType->type > commonType->type) { - *commonType = *newType; - } - return true; - } - - // type 6, float >= float, bool, tinyint, smallint, utinyint, usmallint - if ((commonType->type == TSDB_DATA_TYPE_FLOAT && - ((newType->type >= TSDB_DATA_TYPE_BOOL && newType->type <= TSDB_DATA_TYPE_INT) || - (newType->type >= TSDB_DATA_TYPE_UTINYINT || newType->type <= TSDB_DATA_TYPE_UINT)))|| - (newType->type == TSDB_DATA_TYPE_FLOAT && - ((commonType->type >= TSDB_DATA_TYPE_BOOL && commonType->type <= TSDB_DATA_TYPE_INT) || - (newType->type >= TSDB_DATA_TYPE_UTINYINT || newType->type <= TSDB_DATA_TYPE_UINT)))) { - *commonType = (commonType->type == TSDB_DATA_TYPE_FLOAT) ? *commonType : *newType; - return true; - } - //type 7, double >= double bool, tinyint, smallint, int, utinyint, usmallint, uint, - if ((commonType->type == TSDB_DATA_TYPE_DOUBLE && - ((newType->type >= TSDB_DATA_TYPE_BOOL && newType->type <= TSDB_DATA_TYPE_BIGINT) || - newType->type == TSDB_DATA_TYPE_FLOAT || - (newType->type >= TSDB_DATA_TYPE_UTINYINT && newType->type <= TSDB_DATA_TYPE_UBIGINT) || - newType->type == TSDB_DATA_TYPE_DOUBLE)) || - (newType->type == TSDB_DATA_TYPE_DOUBLE && - ((commonType->type >= TSDB_DATA_TYPE_BOOL && commonType->type <= TSDB_DATA_TYPE_BIGINT) || - commonType->type == TSDB_DATA_TYPE_FLOAT || - (commonType->type >= TSDB_DATA_TYPE_UTINYINT && commonType->type <= TSDB_DATA_TYPE_UBIGINT) || - commonType->type == TSDB_DATA_TYPE_DOUBLE))) { - *commonType = (commonType->type == TSDB_DATA_TYPE_DOUBLE) ? *commonType : *newType; - return true; - } - - // type 11 ~ 14 - if ((commonType->type >= TSDB_DATA_TYPE_UTINYINT && commonType->type <= TSDB_DATA_TYPE_UBIGINT || commonType->type == TSDB_DATA_TYPE_BOOL) && - (newType->type >= TSDB_DATA_TYPE_UTINYINT && newType->type <= TSDB_DATA_TYPE_UBIGINT || newType->type == TSDB_DATA_TYPE_BOOL)) { - if (newType->type > commonType->type) { - *commonType = *newType; - } - return true; - } - - bool commonIsNumeric = ((commonType->type >= TSDB_DATA_TYPE_BOOL && commonType->type <= TSDB_DATA_TYPE_DOUBLE) || - (commonType->type >= TSDB_DATA_TYPE_UTINYINT && commonType->type <= TSDB_DATA_TYPE_UBIGINT)); - bool newIsNumeric = ((newType->type >= TSDB_DATA_TYPE_BOOL && newType->type <= TSDB_DATA_TYPE_DOUBLE) || - (newType->type >= TSDB_DATA_TYPE_UTINYINT && newType->type <= TSDB_DATA_TYPE_UBIGINT)); - - bool commonIsString = (commonType->type == TSDB_DATA_TYPE_VARCHAR || commonType->type == TSDB_DATA_TYPE_NCHAR || commonType->type == TSDB_DATA_TYPE_BINARY); - bool newIsString = (newType->type == TSDB_DATA_TYPE_VARCHAR || newType->type == TSDB_DATA_TYPE_NCHAR || commonType->type == TSDB_DATA_TYPE_BINARY); - // num and string - if ((commonIsNumeric && newIsString) || (commonIsString && newIsNumeric)) { - if (commonIsString) { - return true; - } else { - *commonType = *newType; - return true; - } - } - //char and nchar - if (commonIsString && newIsString) { - if (commonType->type == TSDB_DATA_TYPE_NCHAR || newType->type == TSDB_DATA_TYPE_NCHAR) { - *commonType = (commonType->type == TSDB_DATA_TYPE_NCHAR) ? *commonType : *newType; // nchar first - } else { if (commonType->bytes < newType->bytes) { *commonType = *newType; } - } - return true; + return true; + } + int8_t type1 = commonType->type; + int8_t type2 = newType->type; + int8_t resultType; + if (type1 < type2) { + resultType = gConvertTypes[type1][type2]; + } else { + resultType = gConvertTypes[type2][type1]; + } + if (resultType == -1) { + return false; + } else if (resultType == 0) { + return false; + } else { + commonType->type = resultType; + commonType->bytes = (commonType->bytes >= newType->bytes) ? commonType->bytes : newType->bytes; + return true; } - return false; } static EDealRes translateCaseWhen(STranslateContext* pCxt, SCaseWhenNode* pCaseWhen) { diff --git a/tests/script/tsim/scalar/caseWhen.sim b/tests/script/tsim/scalar/caseWhen.sim index 77bb167286..777f976270 100644 --- a/tests/script/tsim/scalar/caseWhen.sim +++ b/tests/script/tsim/scalar/caseWhen.sim @@ -296,16 +296,16 @@ sql select case when '2' then 'b' when null then 0 end from tba1; if $rows != 4 then return -1 endi -if $data00 != b then +if $data00 != 0 then return -1 endi -if $data10 != b then +if $data10 != 0 then return -1 endi -if $data20 != b then +if $data20 != 0 then return -1 endi -if $data30 != b then +if $data30 != 0 then return -1 endi @@ -1061,4 +1061,69 @@ endi sql_error select case when sum(f1) then sum(f1)-abs(f1) end from tba1; +sql drop database if exists test_db; +sql create database test_db vgroups 5; +sql use test_db; +sql create stable test_stable (ts TIMESTAMP,c_int INT,c_uint INT UNSIGNED, c_bigint BIGINT, c_ubigint BIGINT UNSIGNED, c_float FLOAT, c_double DOUBLE, c_binary BINARY(20), c_smallint SMALLINT, c_usmallint SMALLINT UNSIGNED, c_tinyint TINYINT,c_utinyint TINYINT UNSIGNED,c_bool BOOL,c_nchar NCHAR(20), c_varchar VARCHAR(20)) tags(tag_id INT); +sql create table t_test using test_stable tags(1); +sql insert into t_test values ('2022-09-30 15:15:01',123,456,1234567890,9876543210,123.45,678.90,'binary_val',32767,65535,127,255,true,'涛思数据','varchar_val'); + +sql select case when c_int > 100 then c_float else c_int end as result from t_test; +if $rows != 1 then + return -1 +endi +if $data00 != 123.45000 then + return -1 +endi + +sql select case when c_bigint > 100000 then c_double else c_bigint end as result from t_test; +if $rows != 1 then + return -1 +endi +if $data00 != 678.900000000 then + return -1 +endi + +sql select case when c_bool then c_bool else c_utinyint end as result from t_test; +if $rows != 1 then + return -1 +endi +print $data00 +if $data00 != 1 then + return -1 +endi + +sql select case when c_smallint > 30000 then c_usmallint else c_smallint end as result from t_test; +if $rows != 1 then + return -1 +endi +if $data00 != 65535 then + return -1 +endi + +sql select case when c_binary = 'binary_val' then c_nchar else c_binary end as result from t_test; +if $rows != 1 then + return -1 +endi +if $data00 != 涛思数据 then + return -1 +endi + +sql select case when c_bool then c_int else c_bool end as result from t_test; +if $rows != 1 then + return -1 +endi +if $data00 != 123 then + return -1 +endi + +sql select case when ts > '2022-01-01 00:00:00' then c_bool else ts end as result from t_test; +if $data00 != @70-01-01 08:00:00.001@ then + return -1 +endi + +sql select case when c_double > 100 then c_nchar else c_double end as result from t_test; +if $data00 != 0.000000000 then + return -1 +endi system sh/exec.sh -n dnode1 -s stop -x SIGINT diff --git a/tests/system-test/2-query/case_when.py b/tests/system-test/2-query/case_when.py index 755917faff..cc170bb43b 100755 --- a/tests/system-test/2-query/case_when.py +++ b/tests/system-test/2-query/case_when.py @@ -122,8 +122,8 @@ class TDTestCase: self.constant_check(database,sql1,sql2,5) #TD-20260 - sql1 = "select _wstart,avg(q_int),min(q_smallint) from %s.stable_1 where tbname = 'stable_1_1' and ts < now state_window(case when q_smallint <0 then 1 else 0 end);" %database - sql2 = "select _wstart,avg(q_int),min(q_smallint) from %s.stable_1_1 where ts < now state_window(case when q_smallint <0 then 1 else 0 end);" %database + # sql1 = "select _wstart,avg(q_int),min(q_smallint) from %s.stable_1 where tbname = 'stable_1_1' and ts < now state_window(case when q_smallint <0 then 1 else 0 end);" %database + # sql2 = "select _wstart,avg(q_int),min(q_smallint) from %s.stable_1_1 where ts < now state_window(case when q_smallint <0 then 1 else 0 end);" %database self.constant_check(database,sql1,sql2,0) self.constant_check(database,sql1,sql2,1) self.constant_check(database,sql1,sql2,2) @@ -159,7 +159,7 @@ class TDTestCase: 'first case when q_int < %d then %d when q_int >= %d then %d else %d end last' %(a1,a2,a1,a2,a3), #'first case when q_int < 3 then 1 when q_int >= 3 then 2 else 3 end last' , 'first cast(case q_int when q_int then q_int + (%d) else q_int is null end as double) last' %(a1), #'first cast(case q_int when q_int then q_int + 1 else q_int is null end as double) last' , 'first sum(case q_int when q_int then q_int + (%d) else q_int is null end + (%d)) last' %(a1,a2), #'first sum(case q_int when q_int then q_int + 1 else q_int is null end + 1) last' , - 'first case when q_int is not null then case when q_int <= %d then q_int else q_int * (%d) end else -(%d) end last' %(a1,a1,a3), #'first case when q_int is not null then case when q_int <= 0 then q_int else q_int * 10 end else -1 end last' , + #'first case when q_int is not null then case when q_int <= %d then q_int else q_int * (%d) end else -(%d) end last' %(a1,a1,a3), #'first case when q_int is not null then case when q_int <= 0 then q_int else q_int * 10 end else -1 end last' , 'first case %d when %d then %d end last' %(a1,a2,a3), # 'first case 3 when 3 then 4 end last' , 'first case %d when %d then %d end last' %(a1,a2,a3), # 'first case 3 when 1 then 4 end last' , 'first case %d when %d then %d else %d end last' %(a1,a1,a2,a3), # 'first case 3 when 1 then 4 else 2 end last' , @@ -232,7 +232,7 @@ class TDTestCase: 'first case when \'%d\' then \'b\' else null end last' %(a1), #'first case when \'0\' then \'b\' else null end last', 'first case when \'%d\' then \'b\' else %d end last' %(a1,a2), #'first case when \'0\' then \'b\' else 2 end last', 'first case when q_int then q_int when q_int + (%d) then q_int + (%d) else q_int is null end last' %(a1,a2) , #'first case when q_int then q_int when q_int + 1 then q_int + 1 else q_int is null end last' , - 'first case when q_int then %d when ts then ts end last' %(a1), #'first case when q_int then 3 when ts then ts end last' , + #'first case when q_int then %d when ts then ts end last' %(a1), #'first case when q_int then 3 when ts then ts end last' , 'first case when %d then q_int end last' %(a1), #'first case when 3 then q_int end last' , 'first case when q_int then %d when %d then %d end last' %(a1,a1,a3), #'first case when q_int then 3 when 1 then 2 end last' , 'first case when q_int < %d then %d when q_int >= %d then %d else %d end last' %(a1,a2,a1,a2,a3), #'first case when q_int < 3 then 1 when q_int >= 3 then 2 else 3 end last' , @@ -280,11 +280,12 @@ class TDTestCase: for i in range(30): cs = self.state_window_list().split(',')[i] - # sql1 = "select _wstart,avg(q_int),min(q_smallint) from %s.stable_1 where tbname = 'stable_1_1' state_window(%s);" % (database,cs) - # sql2 = "select _wstart,avg(q_int),min(q_smallint) from %s.stable_1_1 state_window(%s) ;" % (database,cs) - # self.constant_check(database,sql1,sql2,0) - # self.constant_check(database,sql1,sql2,1) - # self.constant_check(database,sql1,sql2,2) + print(cs) + sql1 = "select _wstart,avg(q_int),min(q_smallint) from %s.stable_1 where tbname = 'stable_1_1' state_window(%s);" % (database,cs) + sql2 = "select _wstart,avg(q_int),min(q_smallint) from %s.stable_1_1 state_window(%s) ;" % (database,cs) + self.constant_check(database,sql1,sql2,0) + self.constant_check(database,sql1,sql2,1) + self.constant_check(database,sql1,sql2,2)