test decimal agg funcs

This commit is contained in:
wangjiaming0909 2025-03-05 21:51:37 +08:00
parent 7a1ffd92ac
commit caccf4add9
12 changed files with 327 additions and 87 deletions

View File

@ -177,20 +177,21 @@ typedef struct SColumnDataAgg {
struct { struct {
uint64_t decimal128Sum[2]; uint64_t decimal128Sum[2];
uint64_t decimal128Max[2]; uint64_t decimal128Max[2];
uint64_t decimal128Min[2]; // TODO wjm 1. use deicmal128Sum for decimal64, 2. add overflow flag uint64_t decimal128Min[2]; // TODO wjm 1. use deicmal128Sum for decimal64, 2. add overflow flag
uint8_t overflow;
}; };
}; };
} SColumnDataAgg; } SColumnDataAgg;
#pragma pack(pop) #pragma pack(pop)
#define COL_AGG_GET_SUM_PTR(pAggs, dataType) \ #define COL_AGG_GET_SUM_PTR(pAggs, dataType) \
(dataType != TSDB_DATA_TYPE_DECIMAL ? (void*)&pAggs->sum : (void*)pAggs->decimal128Sum) (!IS_DECIMAL_TYPE(dataType) ? (void*)&pAggs->sum : (void*)pAggs->decimal128Sum)
#define COL_AGG_GET_MAX_PTR(pAggs, dataType) \ #define COL_AGG_GET_MAX_PTR(pAggs, dataType) \
(dataType != TSDB_DATA_TYPE_DECIMAL ? (void*)&pAggs->max : (void*)pAggs->decimal128Max) (!IS_DECIMAL_TYPE(dataType) ? (void*)&pAggs->max : (void*)pAggs->decimal128Max)
#define COL_AGG_GET_MIN_PTR(pAggs, dataType) \ #define COL_AGG_GET_MIN_PTR(pAggs, dataType) \
(dataType != TSDB_DATA_TYPE_DECIMAL ? (void*)&pAggs->min : (void*)pAggs->decimal128Min) (!IS_DECIMAL_TYPE(dataType) ? (void*)&pAggs->min : (void*)pAggs->decimal128Min)
typedef struct SBlockID { typedef struct SBlockID {
// The uid of table, from which current data block comes. And it is always 0, if current block is the // The uid of table, from which current data block comes. And it is always 0, if current block is the

View File

@ -95,6 +95,7 @@ bool decimalCompare(EOperatorType op, const SDecimalCompareCtx* pLeft, const
int32_t decimalOp(EOperatorType op, const SDataType* pLeftT, const SDataType* pRightT, const SDataType* pOutT, int32_t decimalOp(EOperatorType op, const SDataType* pLeftT, const SDataType* pRightT, const SDataType* pOutT,
const void* pLeftData, const void* pRightData, void* pOutputData); const void* pLeftData, const void* pRightData, void* pOutputData);
int32_t convertToDecimal(const void* pData, const SDataType* pInputType, void* pOut, const SDataType* pOutType); int32_t convertToDecimal(const void* pData, const SDataType* pInputType, void* pOut, const SDataType* pOutType);
/**
 * Pre-check for a Decimal128 addition: reports whether `*pLeft + *pRight`
 * would exceed decimal128Max or fall below decimal128Min.
 * Nothing is added; both operands are taken as const.
 *
 * @param pLeft         the Decimal128 accumulator
 * @param pRight        the addend; its width is given by rightWordNum
 * @param rightWordNum  word count of *pRight (callers pass WORD_NUM(Decimal64)
 *                      or WORD_NUM(Decimal))
 * @return true if the addition would overflow, false otherwise
 */
bool decimal128AddCheckOverflow(const Decimal128* pLeft, const DecimalType* pRight, uint8_t rightWordNum);
DEFINE_TYPE_FROM_DECIMAL_FUNCS(, Decimal64); DEFINE_TYPE_FROM_DECIMAL_FUNCS(, Decimal64);
DEFINE_TYPE_FROM_DECIMAL_FUNCS(, Decimal128); DEFINE_TYPE_FROM_DECIMAL_FUNCS(, Decimal128);

View File

@ -4222,29 +4222,36 @@ static FORCE_INLINE void tColDataCalcSMAVarType(SColData *pColData, SColumnDataA
} }
} }
#define CALC_DECIMAL_SUM_MAX_MIN(TYPE, pOps, pColData, pSum, pMax, pMin) \ #define CALC_DECIMAL_SUM_MAX_MIN(TYPE, pSumOp, pCompOp, pColData, pSum, pMax, pMin) \
for (int32_t iVal = 0; iVal < pColData->nVal; ++iVal) { \ do { \
pVal = ((TYPE *)pColData->pData) + iVal; \ if (decimal128AddCheckOverflow((Decimal *)pSum, pVal, WORD_NUM(TYPE))) *pOverflow = true; \
pOps->add(pSum, pVal, WORD_NUM(TYPE)); \ pSumOp->add(pSum, pVal, WORD_NUM(TYPE)); \
if (pOps->gt(pVal, pMax, WORD_NUM(TYPE))) { \ if (pCompOp->gt(pVal, pMax, WORD_NUM(TYPE))) { \
*(pMax) = *pVal; \ *(pMax) = *pVal; \
} \ } \
if (pOps->lt(pVal, pMin, WORD_NUM(TYPE))) { \ if (pCompOp->lt(pVal, pMin, WORD_NUM(TYPE))) { \
*(pMin) = *pVal; \ *(pMin) = *pVal; \
} \ } \
} } while (0)
static FORCE_INLINE void tColDataCalcSMADecimal64Type(SColData* pColData, SColumnDataAgg* pAggs) { static FORCE_INLINE void tColDataCalcSMADecimal64Type(SColData* pColData, SColumnDataAgg* pAggs) {
Decimal64* pSum = (Decimal64*)&pAggs->sum, *pMax = (Decimal64*)&pAggs->max, *pMin = (Decimal64*)&pAggs->min; Decimal128 *pSum = (Decimal128 *)pAggs->decimal128Sum;
*pSum = DECIMAL64_ZERO; Decimal64 *pMax = (Decimal64 *)pAggs->decimal128Max, *pMin = (Decimal64 *)pAggs->decimal128Min;
uint8_t *pOverflow = &pAggs->overflow;
*pSum = DECIMAL128_ZERO;
*pMax = DECIMAL64_MIN; *pMax = DECIMAL64_MIN;
*pMin = DECIMAL64_MAX; *pMin = DECIMAL64_MAX;
pAggs->numOfNull = 0; pAggs->numOfNull = 0;
pAggs->colId |= 0x80000000; // TODO wjm define it
Decimal64 *pVal = NULL; Decimal64 *pVal = NULL;
SDecimalOps *pOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL64); const SDecimalOps *pSumOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL);
const SDecimalOps *pCompOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL64);
if (HAS_VALUE == pColData->flag) { if (HAS_VALUE == pColData->flag) {
CALC_DECIMAL_SUM_MAX_MIN(Decimal64, pOps, pColData, pSum, pMax, pMin); for (int32_t iVal = 0; iVal < pColData->nVal; ++iVal) {
pVal = ((Decimal64*)pColData->pData) + iVal;
CALC_DECIMAL_SUM_MAX_MIN(Decimal64, pSumOps, pCompOps, pColData, pSum, pMax, pMin);
}
} else { } else {
for (int32_t iVal = 0; iVal < pColData->nVal; ++iVal) { for (int32_t iVal = 0; iVal < pColData->nVal; ++iVal) {
switch (tColDataGetBitValue(pColData, iVal)) { switch (tColDataGetBitValue(pColData, iVal)) {
@ -4253,7 +4260,8 @@ static FORCE_INLINE void tColDataCalcSMADecimal64Type(SColData* pColData, SColum
pAggs->numOfNull++; pAggs->numOfNull++;
break; break;
case 2: case 2:
CALC_DECIMAL_SUM_MAX_MIN(Decimal64, pOps, pColData, pSum, pMax, pMin);// TODO wjm what if overflow pVal = ((Decimal64 *)pColData->pData) + iVal;
CALC_DECIMAL_SUM_MAX_MIN(Decimal64, pSumOps, pCompOps, pColData, pSum, pMax, pMin);
break; break;
default: default:
break; break;
@ -4263,7 +4271,9 @@ static FORCE_INLINE void tColDataCalcSMADecimal64Type(SColData* pColData, SColum
} }
static FORCE_INLINE void tColDataCalcSMADecimal128Type(SColData* pColData, SColumnDataAgg* pAggs) { static FORCE_INLINE void tColDataCalcSMADecimal128Type(SColData* pColData, SColumnDataAgg* pAggs) {
Decimal128* pSum = (Decimal128*)pAggs->decimal128Sum, *pMax = (Decimal128*)pAggs->decimal128Max, *pMin = (Decimal128*)pAggs->decimal128Min; Decimal128 *pSum = (Decimal128 *)pAggs->decimal128Sum, *pMax = (Decimal128 *)pAggs->decimal128Max,
*pMin = (Decimal128 *)pAggs->decimal128Min;
uint8_t *pOverflow = &pAggs->overflow;
*pSum = DECIMAL128_ZERO; *pSum = DECIMAL128_ZERO;
*pMax = DECIMAL128_MIN; *pMax = DECIMAL128_MIN;
*pMin = DECIMAL128_MAX; *pMin = DECIMAL128_MAX;
@ -4273,7 +4283,10 @@ static FORCE_INLINE void tColDataCalcSMADecimal128Type(SColData* pColData, SColu
Decimal128 *pVal = NULL; Decimal128 *pVal = NULL;
SDecimalOps* pOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL); SDecimalOps* pOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL);
if (HAS_VALUE == pColData->flag) { if (HAS_VALUE == pColData->flag) {
CALC_DECIMAL_SUM_MAX_MIN(Decimal128, pOps, pColData, pSum, pMax, pMin); for (int32_t iVal = 0; iVal < pColData->nVal; ++iVal) {
pVal = ((Decimal128*)pColData->pData) + iVal;
CALC_DECIMAL_SUM_MAX_MIN(Decimal128, pOps, pOps, pColData, pSum, pMax, pMin);
}
} else { } else {
for (int32_t iVal = 0; iVal < pColData->nVal; ++iVal) { for (int32_t iVal = 0; iVal < pColData->nVal; ++iVal) {
switch (tColDataGetBitValue(pColData, iVal)) { switch (tColDataGetBitValue(pColData, iVal)) {
@ -4282,7 +4295,8 @@ static FORCE_INLINE void tColDataCalcSMADecimal128Type(SColData* pColData, SColu
pAggs->numOfNull++; pAggs->numOfNull++;
break; break;
case 2: case 2:
CALC_DECIMAL_SUM_MAX_MIN(Decimal128, pOps, pColData, pSum, pMax, pMin); pVal = ((Decimal128*)pColData->pData) + iVal;
CALC_DECIMAL_SUM_MAX_MIN(Decimal128, pOps, pOps, pColData, pSum, pMax, pMin);
break; break;
default: default:
break; break;

View File

@ -1577,12 +1577,13 @@ int32_t tPutColumnDataAgg(SBuffer *buffer, SColumnDataAgg *pColAgg) {
if (pColAgg->colId & 0x80000000) { if (pColAgg->colId & 0x80000000) {
if ((code = tBufferPutI32v(buffer, pColAgg->colId))) return code; if ((code = tBufferPutI32v(buffer, pColAgg->colId))) return code;
if ((code = tBufferPutI16v(buffer, pColAgg->numOfNull))) return code; if ((code = tBufferPutI16v(buffer, pColAgg->numOfNull))) return code;
if ((code = tBufferPutI64(buffer, pColAgg->decimal128Sum[0]))) return code; if ((code = tBufferPutU64(buffer, pColAgg->decimal128Sum[0]))) return code;
if ((code = tBufferPutI64(buffer, pColAgg->decimal128Sum[1]))) return code; if ((code = tBufferPutU64(buffer, pColAgg->decimal128Sum[1]))) return code;
if ((code = tBufferPutI64(buffer, pColAgg->decimal128Max[0]))) return code; if ((code = tBufferPutU64(buffer, pColAgg->decimal128Max[0]))) return code;
if ((code = tBufferPutI64(buffer, pColAgg->decimal128Max[1]))) return code; if ((code = tBufferPutU64(buffer, pColAgg->decimal128Max[1]))) return code;
if ((code = tBufferPutI64(buffer, pColAgg->decimal128Min[0]))) return code; if ((code = tBufferPutU64(buffer, pColAgg->decimal128Min[0]))) return code;
if ((code = tBufferPutI64(buffer, pColAgg->decimal128Min[1]))) return code; if ((code = tBufferPutU64(buffer, pColAgg->decimal128Min[1]))) return code;
if ((code = tBufferPutU8(buffer, pColAgg->overflow))) return code;
} else { } else {
if ((code = tBufferPutI32v(buffer, pColAgg->colId))) return code; if ((code = tBufferPutI32v(buffer, pColAgg->colId))) return code;
if ((code = tBufferPutI16v(buffer, pColAgg->numOfNull))) return code; if ((code = tBufferPutI16v(buffer, pColAgg->numOfNull))) return code;
@ -1601,12 +1602,13 @@ int32_t tGetColumnDataAgg(SBufferReader *br, SColumnDataAgg *pColAgg) {
if ((code = tBufferGetI16v(br, &pColAgg->numOfNull))) return code; if ((code = tBufferGetI16v(br, &pColAgg->numOfNull))) return code;
if (pColAgg->colId & 0x80000000) { if (pColAgg->colId & 0x80000000) {
pColAgg->colId &= 0xFFFF; pColAgg->colId &= 0xFFFF;
if ((code = tBufferGetI64(br, &pColAgg->decimal128Sum[0]))) return code; if ((code = tBufferGetU64(br, &pColAgg->decimal128Sum[0]))) return code;
if ((code = tBufferGetI64(br, &pColAgg->decimal128Sum[1]))) return code; if ((code = tBufferGetU64(br, &pColAgg->decimal128Sum[1]))) return code;
if ((code = tBufferGetI64(br, &pColAgg->decimal128Max[0]))) return code; if ((code = tBufferGetU64(br, &pColAgg->decimal128Max[0]))) return code;
if ((code = tBufferGetI64(br, &pColAgg->decimal128Max[1]))) return code; if ((code = tBufferGetU64(br, &pColAgg->decimal128Max[1]))) return code;
if ((code = tBufferGetI64(br, &pColAgg->decimal128Min[0]))) return code; if ((code = tBufferGetU64(br, &pColAgg->decimal128Min[0]))) return code;
if ((code = tBufferGetI64(br, &pColAgg->decimal128Min[1]))) return code; if ((code = tBufferGetU64(br, &pColAgg->decimal128Min[1]))) return code;
if ((code = tBufferGetU8(br, &pColAgg->overflow))) return code;
} else { } else {
if ((code = tBufferGetI64(br, &pColAgg->sum))) return code; if ((code = tBufferGetI64(br, &pColAgg->sum))) return code;
if ((code = tBufferGetI64(br, &pColAgg->max))) return code; if ((code = tBufferGetI64(br, &pColAgg->max))) return code;

View File

@ -1920,6 +1920,18 @@ static int32_t decimal128CountRoundingDelta(const Decimal128* pDec, int8_t scale
return res; return res;
} }
bool decimal128AddCheckOverflow(const Decimal128* pLeft, const DecimalType* pRight, uint8_t rightWordNum) {
if (DECIMAL128_SIGN(pLeft) == 0) {
Decimal128 max = decimal128Max;
decimal128Subtract(&max, pLeft, WORD_NUM(Decimal128));
return decimal128Lt(&max, pRight, rightWordNum);
} else {
Decimal128 min = decimal128Min;
decimal128Subtract(&min, pLeft, WORD_NUM(Decimal128));
return decimal128Gt(&min, pRight, rightWordNum);
}
}
int32_t TEST_decimal64From_int64_t(Decimal64* pDec, uint8_t prec, uint8_t scale, int64_t v) { int32_t TEST_decimal64From_int64_t(Decimal64* pDec, uint8_t prec, uint8_t scale, int64_t v) {
return decimal64FromInt64(pDec, prec, scale, v); return decimal64FromInt64(pDec, prec, scale, v);
} }

View File

@ -1509,6 +1509,17 @@ TEST_F(DecimalTest, decimalFromStr) {
Numeric<128> numeric128 = {38, 10, "0"}; Numeric<128> numeric128 = {38, 10, "0"};
} }
// Exercises decimal128AddCheckOverflow with an all-nines 38-digit accumulator
// of either sign and a small Decimal64 addend of either sign.
TEST(decimal, test_add_check_overflow) {
  Numeric<128> dec128 = {38, 10, "9999999999999999999999999999.9999999999"};
  Numeric<64>  dec64 = {18, 2, "123.12"};

  // Re-reads dec128/dec64 on every call so the operands can be reassigned between checks.
  auto addWouldOverflow = [&]() -> bool {
    return decimal128AddCheckOverflow((Decimal128*)&dec128.dec(), &dec64.dec(), WORD_NUM(Decimal64));
  };

  // positive accumulator + positive addend
  ASSERT_TRUE(addWouldOverflow());

  // negative accumulator + positive addend
  dec128 = {38, 10, "-9999999999999999999999999999.9999999999"};
  ASSERT_FALSE(addWouldOverflow());

  // negative accumulator + negative addend
  dec64 = {18, 2, "-123.1"};
  ASSERT_TRUE(addWouldOverflow());
}
int main(int argc, char** argv) { int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);

View File

@ -87,13 +87,19 @@ typedef struct SDecimalSumRes {
#define SUM_RES_GET_DECIMAL_SUM(pSumRes) ((SDecimalSumRes*)(pSumRes))->sum #define SUM_RES_GET_DECIMAL_SUM(pSumRes) ((SDecimalSumRes*)(pSumRes))->sum
// TODO wjm check for overflow // TODO wjm check for overflow
#define SUM_RES_INC_DECIMAL_SUM(pSumRes, pVal, type) \ #define SUM_RES_INC_DECIMAL_SUM(pSumRes, pVal, type) \
do { \ do { \
const SDecimalOps* pOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL); \ const SDecimalOps* pOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL); \
if (type == TSDB_DATA_TYPE_DECIMAL64) \ int32_t wordNum = 0; \
pOps->add(&SUM_RES_GET_DECIMAL_SUM(pSumRes), pVal, WORD_NUM(Decimal64)); \ if (type == TSDB_DATA_TYPE_DECIMAL64) { \
else \ wordNum = WORD_NUM(Decimal64); \
pOps->add(&SUM_RES_GET_DECIMAL_SUM(pSumRes), pVal, WORD_NUM(Decimal)); \ overflow = decimal128AddCheckOverflow(&SUM_RES_GET_DECIMAL_SUM(pSumRes), pVal, wordNum); \
} else { \
wordNum = WORD_NUM(Decimal); \
overflow = decimal128AddCheckOverflow(&SUM_RES_GET_DECIMAL_SUM(pSumRes), pVal, wordNum); \
} \
if (overflow) break; \
pOps->add(&SUM_RES_GET_DECIMAL_SUM(pSumRes), pVal, wordNum); \
} while (0) } while (0)
typedef struct SMinmaxResInfo { typedef struct SMinmaxResInfo {

View File

@ -106,17 +106,19 @@ typedef enum {
} \ } \
} while (0) } while (0)
#define LIST_ADD_DECIMAL_N(_res, _col, _start, _rows, _t, numOfElem) \ #define LIST_ADD_DECIMAL_N(_res, _col, _start, _rows, _t, numOfElem) \
do { \ do { \
_t* d = (_t*)(_col->pData); \ _t* d = (_t*)(_col->pData); \
const SDecimalOps* pOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL); \ const SDecimalOps* pOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL); \
for (int32_t i = (_start); i < (_rows) + (_start); ++i) { \ for (int32_t i = (_start); i < (_rows) + (_start); ++i) { \
if (((_col)->hasNull) && colDataIsNull_f((_col)->nullbitmap, i)) { \ if (((_col)->hasNull) && colDataIsNull_f((_col)->nullbitmap, i)) { \
continue; \ continue; \
}; \ }; \
pOps->add(_res, d + i, WORD_NUM(_t)); \ overflow = overflow || decimal128AddCheckOverflow((Decimal*)_res, d + i, WORD_NUM(_t)); \
(numOfElem)++; \ if (overflow) break; \
} \ pOps->add(_res, d + i, WORD_NUM(_t)); \
(numOfElem)++; \
} \
} while (0) } while (0)
#define LIST_SUB_N(_res, _col, _start, _rows, _t, numOfElem) \ #define LIST_SUB_N(_res, _col, _start, _rows, _t, numOfElem) \
@ -652,12 +654,12 @@ int32_t sumFunction(SqlFunctionCtx* pCtx) {
SUM_RES_INC_DSUM(pSumRes, GET_DOUBLE_VAL((const char*)&(pAgg->sum))); SUM_RES_INC_DSUM(pSumRes, GET_DOUBLE_VAL((const char*)&(pAgg->sum)));
} else if (IS_DECIMAL_TYPE(type)) { } else if (IS_DECIMAL_TYPE(type)) {
SUM_RES_SET_TYPE(pSumRes, pCtx->inputType, TSDB_DATA_TYPE_DECIMAL); SUM_RES_SET_TYPE(pSumRes, pCtx->inputType, TSDB_DATA_TYPE_DECIMAL);
const SDecimalOps* pOps = getDecimalOps(type); const SDecimalOps* pOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL);
if (TSDB_DATA_TYPE_DECIMAL64 == type) { if (pAgg->overflow || decimal128AddCheckOverflow((Decimal*)&SUM_RES_GET_DECIMAL_SUM(pSumRes),
pOps->add(&SUM_RES_GET_DECIMAL_SUM(pSumRes), &pAgg->sum, WORD_NUM(Decimal64)); &pAgg->decimal128Sum, WORD_NUM(Decimal))) {
} else if (TSDB_DATA_TYPE_DECIMAL == type) { return TSDB_CODE_DECIMAL_OVERFLOW;
pOps->add(&SUM_RES_GET_DECIMAL_SUM(pSumRes), &pAgg->decimal128Sum, WORD_NUM(Decimal));
} }
pOps->add(&SUM_RES_GET_DECIMAL_SUM(pSumRes), &pAgg->decimal128Sum, WORD_NUM(Decimal));
} }
} else { // computing based on the true data block } else { // computing based on the true data block
SColumnInfoData* pCol = pInput->pData[0]; SColumnInfoData* pCol = pInput->pData[0];
@ -691,12 +693,13 @@ int32_t sumFunction(SqlFunctionCtx* pCtx) {
LIST_ADD_N(SUM_RES_GET_DSUM(pSumRes), pCol, start, numOfRows, float, numOfElem); LIST_ADD_N(SUM_RES_GET_DSUM(pSumRes), pCol, start, numOfRows, float, numOfElem);
} else if (IS_DECIMAL_TYPE(type)) { } else if (IS_DECIMAL_TYPE(type)) {
SUM_RES_SET_TYPE(pSumRes, pCtx->inputType, TSDB_DATA_TYPE_DECIMAL); SUM_RES_SET_TYPE(pSumRes, pCtx->inputType, TSDB_DATA_TYPE_DECIMAL);
int32_t overflow = false;
if (TSDB_DATA_TYPE_DECIMAL64 == type) { if (TSDB_DATA_TYPE_DECIMAL64 == type) {
LIST_ADD_DECIMAL_N(&SUM_RES_GET_DECIMAL_SUM(pSumRes), pCol, start, numOfRows, Decimal64, numOfElem); LIST_ADD_DECIMAL_N(&SUM_RES_GET_DECIMAL_SUM(pSumRes), pCol, start, numOfRows, Decimal64, numOfElem);
} else if (TSDB_DATA_TYPE_DECIMAL == type) { } else if (TSDB_DATA_TYPE_DECIMAL == type) {
LIST_ADD_DECIMAL_N(&SUM_RES_GET_DECIMAL_SUM(pSumRes), pCol, start, numOfRows, Decimal128, numOfElem); LIST_ADD_DECIMAL_N(&SUM_RES_GET_DECIMAL_SUM(pSumRes), pCol, start, numOfRows, Decimal128, numOfElem);
} }
// TODO wjm check overflow if (overflow) return TSDB_CODE_DECIMAL_OVERFLOW;
} }
} }

View File

@ -126,7 +126,7 @@ int32_t avgFunctionSetup(SqlFunctionCtx* pCtx, SResultRowEntryInfo* pResultInfo)
return TSDB_CODE_SUCCESS; return TSDB_CODE_SUCCESS;
} }
static int32_t calculateAvgBySMAInfo(void* pRes, int32_t numOfRows, int32_t type, const SColumnDataAgg* pAgg) { static int32_t calculateAvgBySMAInfo(void* pRes, int32_t numOfRows, int32_t type, const SColumnDataAgg* pAgg, int32_t* pNumOfElem) {
int32_t numOfElem = numOfRows - pAgg->numOfNull; int32_t numOfElem = numOfRows - pAgg->numOfNull;
AVG_RES_INC_COUNT(pRes, type, numOfElem); AVG_RES_INC_COUNT(pRes, type, numOfElem);
@ -137,16 +137,17 @@ static int32_t calculateAvgBySMAInfo(void* pRes, int32_t numOfRows, int32_t type
} else if (IS_FLOAT_TYPE(type)) { } else if (IS_FLOAT_TYPE(type)) {
SUM_RES_INC_DSUM(&AVG_RES_GET_SUM(pRes), GET_DOUBLE_VAL((const char*)&(pAgg->sum))); SUM_RES_INC_DSUM(&AVG_RES_GET_SUM(pRes), GET_DOUBLE_VAL((const char*)&(pAgg->sum)));
} else if (IS_DECIMAL_TYPE(type)) { } else if (IS_DECIMAL_TYPE(type)) {
if (type == TSDB_DATA_TYPE_DECIMAL64) bool overflow = pAgg->overflow;
SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pRes), &pAgg->sum, TSDB_DATA_TYPE_DECIMAL64); if (overflow) return TSDB_CODE_DECIMAL_OVERFLOW;
else SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pRes), &pAgg->decimal128Sum, TSDB_DATA_TYPE_DECIMAL);
SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pRes), &pAgg->decimal128Sum, TSDB_DATA_TYPE_DECIMAL); if (overflow) return TSDB_CODE_DECIMAL_OVERFLOW;
} }
return numOfElem; *pNumOfElem = numOfElem;
return 0;
} }
static int32_t doAddNumericVector(SColumnInfoData* pCol, int32_t type, SInputColumnInfoData *pInput, void* pRes) { static int32_t doAddNumericVector(SColumnInfoData* pCol, int32_t type, SInputColumnInfoData *pInput, void* pRes, int32_t* pNumOfElem) {
int32_t start = pInput->startRowIndex; int32_t start = pInput->startRowIndex;
int32_t numOfRows = pInput->numOfRows; int32_t numOfRows = pInput->numOfRows;
int32_t numOfElems = 0; int32_t numOfElems = 0;
@ -306,14 +307,17 @@ static int32_t doAddNumericVector(SColumnInfoData* pCol, int32_t type, SInputCol
numOfElems += 1; numOfElems += 1;
AVG_RES_INC_COUNT(pRes, type, 1); AVG_RES_INC_COUNT(pRes, type, 1);
bool overflow = false;
SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pRes), (const void*)(pDec + i * tDataTypes[type].bytes), type); SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pRes), (const void*)(pDec + i * tDataTypes[type].bytes), type);
if (overflow) return TSDB_CODE_DECIMAL_OVERFLOW;
} }
} break; } break;
default: default:
break; break;
} }
return numOfElems; *pNumOfElem = numOfElems;
return 0;
} }
int32_t avgFunction(SqlFunctionCtx* pCtx) { int32_t avgFunction(SqlFunctionCtx* pCtx) {
@ -341,7 +345,8 @@ int32_t avgFunction(SqlFunctionCtx* pCtx) {
if (IS_DECIMAL_TYPE(type)) AVG_RES_SET_INPUT_SCALE(pAvgRes, pInput->pData[0]->info.scale); if (IS_DECIMAL_TYPE(type)) AVG_RES_SET_INPUT_SCALE(pAvgRes, pInput->pData[0]->info.scale);
if (pInput->colDataSMAIsSet) { // try to use SMA if available if (pInput->colDataSMAIsSet) { // try to use SMA if available
numOfElem = calculateAvgBySMAInfo(pAvgRes, numOfRows, type, pAgg); int32_t code = calculateAvgBySMAInfo(pAvgRes, numOfRows, type, pAgg, &numOfElem);
if (code != 0) return code;
} else if (!pCol->hasNull) { // try to employ the simd instructions to speed up the loop } else if (!pCol->hasNull) { // try to employ the simd instructions to speed up the loop
numOfElem = pInput->numOfRows; numOfElem = pInput->numOfRows;
AVG_RES_INC_COUNT(pAvgRes, pCtx->inputType, pInput->numOfRows); AVG_RES_INC_COUNT(pAvgRes, pCtx->inputType, pInput->numOfRows);
@ -424,6 +429,7 @@ int32_t avgFunction(SqlFunctionCtx* pCtx) {
const char* pDec = pCol->pData; const char* pDec = pCol->pData;
// TODO wjm check for overflow // TODO wjm check for overflow
for (int32_t i = pInput->startRowIndex; i < pInput->numOfRows + pInput->startRowIndex; ++i) { for (int32_t i = pInput->startRowIndex; i < pInput->numOfRows + pInput->startRowIndex; ++i) {
bool overflow = false;
if (type == TSDB_DATA_TYPE_DECIMAL64) { if (type == TSDB_DATA_TYPE_DECIMAL64) {
SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pAvgRes), (const void*)(pDec + i * tDataTypes[type].bytes), SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pAvgRes), (const void*)(pDec + i * tDataTypes[type].bytes),
TSDB_DATA_TYPE_DECIMAL64); TSDB_DATA_TYPE_DECIMAL64);
@ -431,13 +437,14 @@ int32_t avgFunction(SqlFunctionCtx* pCtx) {
SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pAvgRes), (const void*)(pDec + i * tDataTypes[type].bytes), SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pAvgRes), (const void*)(pDec + i * tDataTypes[type].bytes),
TSDB_DATA_TYPE_DECIMAL); TSDB_DATA_TYPE_DECIMAL);
} }
if (overflow) return TSDB_CODE_DECIMAL_OVERFLOW;
} }
} break; } break;
default: default:
return TSDB_CODE_FUNC_FUNTION_PARA_TYPE; return TSDB_CODE_FUNC_FUNTION_PARA_TYPE;
} }
} else { } else {
numOfElem = doAddNumericVector(pCol, type, pInput, pAvgRes); int32_t code = doAddNumericVector(pCol, type, pInput, pAvgRes, &numOfElem);
} }
_over: _over:
@ -446,12 +453,12 @@ _over:
return TSDB_CODE_SUCCESS; return TSDB_CODE_SUCCESS;
} }
static void avgTransferInfo(SqlFunctionCtx* pCtx, void* pInput, void* pOutput) { static int32_t avgTransferInfo(SqlFunctionCtx* pCtx, void* pInput, void* pOutput) {
int32_t inputDT = pCtx->pExpr->pExpr->_function.pFunctNode->srcFuncInputType.type; int32_t inputDT = pCtx->pExpr->pExpr->_function.pFunctNode->srcFuncInputType.type;
int32_t type = AVG_RES_GET_TYPE(pInput, inputDT); int32_t type = AVG_RES_GET_TYPE(pInput, inputDT);
pCtx->inputType = type; pCtx->inputType = type;
if (IS_NULL_TYPE(type)) { if (IS_NULL_TYPE(type)) {
return; return 0;
} }
@ -464,12 +471,15 @@ static void avgTransferInfo(SqlFunctionCtx* pCtx, void* pInput, void* pOutput) {
CHECK_OVERFLOW_SUM_UNSIGNED_BIG(pOutput, (overflow ? SUM_RES_GET_DSUM(&AVG_RES_GET_SUM(pInput)) : SUM_RES_GET_USUM(&AVG_RES_GET_SUM(pInput))), overflow); CHECK_OVERFLOW_SUM_UNSIGNED_BIG(pOutput, (overflow ? SUM_RES_GET_DSUM(&AVG_RES_GET_SUM(pInput)) : SUM_RES_GET_USUM(&AVG_RES_GET_SUM(pInput))), overflow);
} else if (IS_DECIMAL_TYPE(type)) { } else if (IS_DECIMAL_TYPE(type)) {
AVG_RES_SET_INPUT_SCALE(pOutput, AVG_RES_GET_INPUT_SCALE(pInput)); AVG_RES_SET_INPUT_SCALE(pOutput, AVG_RES_GET_INPUT_SCALE(pInput));
bool overflow = false;
SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pOutput), &AVG_RES_GET_DECIMAL_SUM(pInput), TSDB_DATA_TYPE_DECIMAL); SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pOutput), &AVG_RES_GET_DECIMAL_SUM(pInput), TSDB_DATA_TYPE_DECIMAL);
if (overflow) return TSDB_CODE_DECIMAL_OVERFLOW;
} else { } else {
SUM_RES_INC_DSUM(&AVG_RES_GET_SUM(pOutput), SUM_RES_GET_DSUM(&AVG_RES_GET_SUM(pInput))); SUM_RES_INC_DSUM(&AVG_RES_GET_SUM(pOutput), SUM_RES_GET_DSUM(&AVG_RES_GET_SUM(pInput)));
} }
AVG_RES_INC_COUNT(pOutput, type, AVG_RES_GET_COUNT(pInput, true, type)); AVG_RES_INC_COUNT(pOutput, type, AVG_RES_GET_COUNT(pInput, true, type));
return 0;
} }
int32_t avgFunctionMerge(SqlFunctionCtx* pCtx) { int32_t avgFunctionMerge(SqlFunctionCtx* pCtx) {
@ -493,7 +503,8 @@ int32_t avgFunctionMerge(SqlFunctionCtx* pCtx) {
if(colDataIsNull_s(pCol, i)) continue; if(colDataIsNull_s(pCol, i)) continue;
char* data = colDataGetData(pCol, i); char* data = colDataGetData(pCol, i);
void* pInputInfo = varDataVal(data); void* pInputInfo = varDataVal(data);
avgTransferInfo(pCtx, pInputInfo, pInfo); int32_t code = avgTransferInfo(pCtx, pInputInfo, pInfo);
if (code != 0) return code;
} }
SET_VAL(GET_RES_INFO(pCtx), 1, 1); SET_VAL(GET_RES_INFO(pCtx), 1, 1);

View File

@ -619,7 +619,7 @@ int32_t doMinMaxHelper(SqlFunctionCtx* pCtx, int32_t isMinFunc, int32_t* nElems)
int16_t index = 0; int16_t index = 0;
void* tval = NULL; void* tval = NULL;
if (type == TSDB_DATA_TYPE_DECIMAL) { if (IS_DECIMAL_TYPE(type)) {
tval = isMinFunc ? pInput->pColumnDataAgg[0]->decimal128Min : pInput->pColumnDataAgg[0]->decimal128Max; tval = isMinFunc ? pInput->pColumnDataAgg[0]->decimal128Min : pInput->pColumnDataAgg[0]->decimal128Max;
} else { } else {
tval = (isMinFunc) ? &pInput->pColumnDataAgg[0]->min : &pInput->pColumnDataAgg[0]->max; tval = (isMinFunc) ? &pInput->pColumnDataAgg[0]->min : &pInput->pColumnDataAgg[0]->max;

View File

@ -4228,7 +4228,7 @@ static int32_t fltSclBuildDecimalDatumFromValueNode(SFltSclDatum* datum, SColumn
if (datum->kind == FLT_SCL_DATUM_KIND_DECIMAL64 || datum->kind == FLT_SCL_DATUM_KIND_DECIMAL) { if (datum->kind == FLT_SCL_DATUM_KIND_DECIMAL64 || datum->kind == FLT_SCL_DATUM_KIND_DECIMAL) {
int32_t code = convertToDecimal(pInput, &valDt, pData, &datum->type); int32_t code = convertToDecimal(pInput, &valDt, pData, &datum->type);
if (TSDB_CODE_SUCCESS != code) return code; // TODO wjm handle overflow error if (TSDB_CODE_SUCCESS != code) return code; // TODO wjm handle overflow error
valNode->node.resType = datum->type; //valNode->node.resType = datum->type;
} }
} }
FLT_RET(0); FLT_RET(0);

View File

@ -7,8 +7,10 @@ import time
import threading import threading
import secrets import secrets
import numpy import numpy
from paramiko import Agent
import query import query
from tag_lite import datatype
from util.log import * from util.log import *
from util.sql import * from util.sql import *
from util.cases import * from util.cases import *
@ -43,12 +45,13 @@ scalar_convert_err = -2147470768
decimal_insert_validator_test = False decimal_insert_validator_test = False
operator_test_round = 10 operator_test_round = 1
tb_insert_rows = 1000 tb_insert_rows = 1000
binary_op_with_const_test = True binary_op_with_const_test = True
binary_op_with_col_test = True binary_op_with_col_test = True
unary_op_test = True unary_op_test = True
binary_op_in_where_test = True binary_op_in_where_test = True
test_decimal_funcs = True
class DecimalTypeGeneratorConfig: class DecimalTypeGeneratorConfig:
def __init__(self): def __init__(self):
@ -156,14 +159,14 @@ class DecimalColumnAggregator:
self.null_num: int = 0 self.null_num: int = 0
self.none_num: int = 0 self.none_num: int = 0
def add_value(self, value: str): def add_value(self, value: str, scale: int):
self.count += 1 self.count += 1
if value == "NULL": if value == "NULL":
self.null_num += 1 self.null_num += 1
elif value == "None": elif value == "None":
self.none_num += 1 self.none_num += 1
else: else:
v: Decimal = Decimal(value) v: Decimal = get_decimal(value, scale)
self.sum += v self.sum += v
if v > self.max: if v > self.max:
self.max = v self.max = v
@ -245,6 +248,10 @@ class DecimalColumnExpr:
def should_skip_for_decimal(self, cols: list): def should_skip_for_decimal(self, cols: list):
return False return False
def check_query_results(self, query_col_res: List, tbname: str):
    """Placeholder for validating one result column of a query.

    This implementation performs no validation and always returns None.

    :param query_col_res: values returned by the query for this expression
    :param tbname: name of the table that was queried
    """
    # The previous body computed len(query_col_res) into an unused local and
    # then fell through a redundant `pass`; neither had any effect on the result.
    return None
def check_for_filtering(self, query_col_res: List, tbname: str): def check_for_filtering(self, query_col_res: List, tbname: str):
j: int = -1 j: int = -1
for i in range(len(query_col_res)): for i in range(len(query_col_res)):
@ -479,6 +486,7 @@ class DataType:
return val return val
class DecimalType(DataType): class DecimalType(DataType):
MAX_PRECISION = 38
def __init__(self, type, precision: int, scale: int): def __init__(self, type, precision: int, scale: int):
self.precision_ = precision self.precision_ = precision
self.scale_ = scale self.scale_ = scale
@ -497,6 +505,14 @@ class DecimalType(DataType):
def get_decimal_type_mod(self) -> int: def get_decimal_type_mod(self) -> int:
return self.precision_ * 100 + self.scale() return self.precision_ * 100 + self.scale()
def set_prec(self, prec: int):
    """Replace the precision and recompute ``type_mod`` so the two stay in sync."""
    self.precision_ = prec
    refreshed_mod = self.get_decimal_type_mod()
    self.type_mod = refreshed_mod
def set_scale(self, scale: int):
    """Replace the scale and recompute ``type_mod`` so the two stay in sync."""
    self.scale_ = scale
    refreshed_mod = self.get_decimal_type_mod()
    self.type_mod = refreshed_mod
def prec(self): def prec(self):
return self.precision_ return self.precision_
@ -520,7 +536,7 @@ class DecimalType(DataType):
def generate_value(self) -> str: def generate_value(self) -> str:
val = self.decimal_generator.generate(self.generator_config) val = self.decimal_generator.generate(self.generator_config)
self.aggregator.add_value(val) ## convert to Decimal first self.aggregator.add_value(val, self.scale()) ## convert to Decimal first
# self.values.append(val) ## save it into files maybe # self.values.append(val) ## save it into files maybe
return val return val
@ -606,6 +622,12 @@ class Column:
return self.get_constant_val_for_execute() return self.get_constant_val_for_execute()
return self.get_typed_val_for_execute(self.saved_vals[tbname][idx]) return self.get_typed_val_for_execute(self.saved_vals[tbname][idx])
def get_cardinality(self, tbname):
    """Number of values this column contributes for ``tbname``.

    A constant column counts as a single value; otherwise the count is the
    number of values saved for that table.
    """
    if self.is_constant_col():
        return 1
    saved_for_table = self.saved_vals[tbname]
    return len(saved_for_table)
## tbName: for normal table, pass the tbname, for child table, pass the child table name ## tbName: for normal table, pass the tbname, for child table, pass the child table name
def generate_value(self, tbName: str = '', save: bool = True): def generate_value(self, tbName: str = '', save: bool = True):
val = self.type_.generate_value() val = self.type_.generate_value()
@ -824,6 +846,143 @@ class TableDataValidator:
col.check(res[colIdx], row_num * self.tbIdx) col.check(res[colIdx], row_num * self.tbIdx)
colIdx += 1 colIdx += 1
class DecimalFunction(DecimalColumnExpr):
    """A named function-call expression over a column, e.g. ``min({0})``.

    Subclasses supply the SQL format string, the Python-side executor and the
    logic that compares the query output with the Python-side expectation.
    """

    def __init__(self, format, executor, name: str):
        super().__init__(format, executor)
        self.func_name_ = name

    def is_agg_func(self, op: str):
        """This class reports False for every ``op``."""
        return False

    def get_func_res(self):
        """Python-side expected result; this class has none and returns None."""
        return None

    def get_input_types(self):
        """The inputs of the function: a one-element list holding the queried column."""
        return [self.query_col]

    @staticmethod
    def get_decimal_agg_funcs() -> List:
        """Build one new instance of each decimal aggregate: min, max, sum, avg."""
        agg_classes = (DecimalMinFunction, DecimalMaxFunction, DecimalSumFunction, DecimalAvgFunction)
        return [agg_cls() for agg_cls in agg_classes]

    def check_results(self, query_col_res: List) -> bool:
        """Compare query output with the expectation; this class always reports False."""
        return False

    def check_for_agg_func(self, query_col_res: List, tbname: str):
        """Feed every value of the queried column to the executor, then verify the query output.

        Calls ``tdLog.exit`` with a diagnostic message when ``check_results`` rejects the output.
        """
        source_col = self.query_col
        row_count = source_col.get_cardinality(tbname)
        for row in range(row_count):
            self.execute((source_col.get_val_for_execute(tbname, row),))
        if self.check_results(query_col_res):
            return
        tdLog.exit(f"check failed for {self}, query got: {query_col_res}, expect {self.get_func_res()}")
class DecimalAggFunction(DecimalFunction):
    """Common behaviour of the decimal aggregate functions."""

    def __init__(self, format, executor, name: str):
        super().__init__(format, executor, name)

    def is_agg_func(self, op):
        """Every function derived from this class is an aggregate."""
        return True

    def should_skip_for_decimal(self, cols: list):
        """Skip the function unless its first input column has a decimal type."""
        first_col: Column = cols[0]
        return not first_col.type_.is_decimal_type()

    def check_results(self, query_col_res):
        """Compare the first query value with the Python-side result.

        An empty query result is logged and accepted.
        """
        if len(query_col_res) != 0:
            return self.get_func_res() == Decimal(query_col_res[0])
        tdLog.info(f"query got no output: {self}, py calc: {self.get_func_res()}")
        return True
class DecimalMinFunction(DecimalAggFunction):
    """``min(col)`` over a decimal column."""

    def __init__(self):
        super().__init__("min({0})", DecimalMinFunction.execute_min, "min")
        # Running minimum maintained by execute_min; None until a value is seen.
        self.min_: Decimal = None

    def get_func_res(self) -> Decimal:
        """Expected minimum, taken from the aggregator attached to the column's type.

        The unreachable ``return self.min_`` that used to follow has been
        removed: the aggregator value was always the one returned.
        """
        decimal_type: DecimalType = self.query_col.type_
        return decimal_type.aggregator.min

    def generate_res_type(self) -> DataType:
        """The result of min() has the same type as the input column."""
        self.res_type_ = self.query_col.type_

    def execute_min(self, params):
        """Fold ``params[0]`` into the running minimum; a None input is ignored.

        Returns the running minimum, or None when the input was None.
        """
        if params[0] is None:
            return
        if self.min_ is None:
            self.min_ = Decimal(params[0])
        else:
            self.min_ = min(self.min_, Decimal(params[0]))
        return self.min_
class DecimalMaxFunction(DecimalAggFunction):
    """Expected-result model for ``max(col)`` over a decimal column."""

    def __init__(self):
        super().__init__("max({0})", DecimalMaxFunction.execute_max, "max")
        # Running maximum; None until the first non-NULL value is folded in.
        # NOTE(review): nothing in this class resets max_, so reusing one instance
        # for several columns carries the previous maximum over - confirm with callers.
        self.max_: Decimal = None

    def get_func_res(self) -> Decimal:
        """Return the running maximum accumulated by ``execute_max``."""
        return self.max_

    def generate_res_type(self) -> DataType:
        """Set the result type to the query column's type (stored in ``res_type_``)."""
        self.res_type_ = self.query_col.type_

    def execute_max(self, params):
        """Fold one input value into the running maximum.

        :param params: tuple whose first element is the column value (or None for NULL).
        :return: the running maximum, or None when the value is NULL.
        """
        raw_value = params[0]
        if raw_value is None:
            return
        candidate = Decimal(raw_value)
        self.max_ = candidate if self.max_ is None else max(self.max_, candidate)
        return self.max_
class DecimalSumFunction(DecimalAggFunction):
    """Expected-result model for ``sum(col)`` over a decimal column."""

    def __init__(self):
        super().__init__("sum({0})", DecimalSumFunction.execute_sum, "sum")
        # Running sum maintained by execute_sum; None until a non-NULL value is seen.
        self.sum_: Decimal = None

    def get_func_res(self) -> Decimal:
        """Return the sum tracked by the aggregator attached to the column's type.

        The running ``sum_`` kept by ``execute_sum`` is not consulted here; the
        previous unreachable ``return self.sum_`` after this return was removed.
        """
        decimal_type: DecimalType = self.query_col.type_
        return decimal_type.aggregator.sum

    def generate_res_type(self) -> DataType:
        """Set the result type to the column's type widened to the maximum precision.

        The column's type object is copied first so that widening the result
        precision does not alter the column's own declared precision.
        """
        import copy

        # Shallow copy: assumes set_prec only rebinds attributes on the type
        # object itself - TODO confirm against DecimalType.set_prec.
        res_type = copy.copy(self.query_col.type_)
        res_type.set_prec(DecimalType.MAX_PRECISION)
        self.res_type_ = res_type

    def execute_sum(self, params):
        """Add one input value to the running sum.

        :param params: tuple whose first element is the column value (or None for NULL).
        :return: the running sum, or None when the value is NULL.
        """
        if params[0] is None:
            # NULL values do not participate in the aggregate.
            return
        if self.sum_ is None:
            self.sum_ = Decimal(params[0])
        else:
            self.sum_ += Decimal(params[0])
        return self.sum_
class DecimalAvgFunction(DecimalAggFunction):
    """Expected-result model for ``avg(col)`` over a decimal column."""

    def __init__(self):
        super().__init__("avg({0})", DecimalAvgFunction.execute_avg, "avg")
        # Number of non-NULL values folded in by execute_avg.
        self.count_: int = 0
        # Running sum of non-NULL values; None until the first one is seen.
        self.sum_: Decimal = None

    def get_func_res(self) -> Decimal:
        """Return sum / non-NULL count from the column type's aggregator, at the result scale.

        Returns None when the aggregator reports no non-NULL rows, instead of
        raising on a division by zero.
        """
        decimal_type: DecimalType = self.query_col.type_
        aggregator = decimal_type.aggregator
        non_null_count = aggregator.count - aggregator.null_num
        if non_null_count == 0:
            return None
        return get_decimal(
            aggregator.sum / non_null_count,
            self.res_type_.scale(),
        )

    def generate_res_type(self) -> DataType:
        """Derive the result type as (column type at max precision) / BIGINT.

        The column's type object is copied first so that widening the sum
        precision does not alter the column's own declared precision.
        """
        import copy

        # Shallow copy: assumes set_prec only rebinds attributes on the type
        # object itself - TODO confirm against DecimalType.set_prec.
        sum_type = copy.copy(self.query_col.type_)
        sum_type.set_prec(DecimalType.MAX_PRECISION)
        count_type = DataType(TypeEnum.BIGINT, 8, 0)
        self.res_type_ = DecimalBinaryOperator.calc_decimal_prec_scale(sum_type, count_type, "/")

    def execute_avg(self, params):
        """Fold one input value into the running sum/count.

        :param params: tuple whose first element is the column value (or None for NULL).
        :return: the current result of get_func_res(), or None when the value is NULL.
        """
        if params[0] is None:
            # NULL values do not participate in the aggregate.
            return
        if self.sum_ is None:
            self.sum_ = Decimal(params[0])
        else:
            self.sum_ += Decimal(params[0])
        self.count_ += 1
        return self.get_func_res()
class DecimalBinaryOperator(DecimalColumnExpr): class DecimalBinaryOperator(DecimalColumnExpr):
def __init__(self, format, executor, op: str): def __init__(self, format, executor, op: str):
@ -845,7 +1004,7 @@ class DecimalBinaryOperator(DecimalColumnExpr):
return True return True
if self.op_ != "%": if self.op_ != "%":
return False return False
## why skip decimal % float/double? it's wrong now. ## TODO wjm why skip decimal % float/double? it's wrong now.
left_is_real = left_col.type_.is_real_type() or left_col.type_.is_varchar_type() left_is_real = left_col.type_.is_real_type() or left_col.type_.is_varchar_type()
right_is_real = right_col.type_.is_real_type() or right_col.type_.is_varchar_type() right_is_real = right_col.type_.is_real_type() or right_col.type_.is_varchar_type()
if left_is_real or right_is_real: if left_is_real or right_is_real:
@ -1638,10 +1797,6 @@ class TDTestCase:
self.test_query_decimal_where_clause() self.test_query_decimal_where_clause()
def test_decimal_functions(self):
    """Run the decimal function tests; currently only the last/first check."""
    # The unused local `funcs` list was dropped; it named:
    # max, min, sum, avg, count, first, last, cast.
    self.test_decimal_last_first_func()
def test_decimal_last_first_func(self): def test_decimal_last_first_func(self):
pass pass
@ -1663,9 +1818,7 @@ class TDTestCase:
for col in tb_cols: for col in tb_cols:
if col.name_ == '': if col.name_ == '':
continue continue
left_is_decimal = col.type_.is_decimal_type()
for const_col in constant_cols: for const_col in constant_cols:
right_is_decimal = const_col.type_.is_decimal_type()
if expr.should_skip_for_decimal([col, const_col]): if expr.should_skip_for_decimal([col, const_col]):
continue continue
const_col.generate_value() const_col.generate_value()
@ -1765,6 +1918,32 @@ class TDTestCase:
def test_query_decimal_case_when(self): def test_query_decimal_case_when(self):
pass pass
def test_decimal_agg_funcs(self, dbname, tbname, tb_cols: List[Column], get_agg_funcs_func):
    """Query every aggregate function over every eligible column and verify the result.

    :param dbname: database holding the table.
    :param tbname: table to query.
    :param tb_cols: columns of the table; unnamed or non-eligible columns are skipped.
    :param get_agg_funcs_func: zero-argument factory returning the list of functions to test.

    A new function instance is obtained from the factory for every column, so
    accumulators kept on an instance (for example a running max) cannot leak
    from one column into the expected value of the next.
    """
    # Assumes the factory returns the same functions in the same order on
    # every call - verify if a different factory is ever passed in.
    func_count = len(get_agg_funcs_func())
    for func_idx in range(func_count):
        for col in tb_cols:
            func: DecimalFunction = get_agg_funcs_func()[func_idx]
            if col.name_ == '' or func.should_skip_for_decimal([col]):
                continue
            func.query_col = col
            select_expr = func.generate([col])
            sql = f"select {select_expr} from {dbname}.{tbname}"
            res = TaosShell().query(sql)
            if len(res) > 0:
                # Aggregates yield a single row; keep only that row.
                res = res[0]
            func.check_for_agg_func(res, tbname)
def test_decimal_functions(self):
    """Run the decimal function tests when the module-level ``test_decimal_funcs`` switch is set."""
    if test_decimal_funcs:
        # Aggregate functions are checked against the normal table first.
        agg_func_factory = DecimalFunction.get_decimal_agg_funcs
        self.test_decimal_agg_funcs(self.db_name, self.norm_table_name, self.norm_tb_columns, agg_func_factory)
        self.test_decimal_last_first_func()
def test_query_decimal(self): def test_query_decimal(self):
self.test_decimal_operators() self.test_decimal_operators()
self.test_decimal_functions() self.test_decimal_functions()