test decimal agg funcs

This commit is contained in:
wangjiaming0909 2025-03-05 21:51:37 +08:00
parent 7a1ffd92ac
commit caccf4add9
12 changed files with 327 additions and 87 deletions

View File

@ -177,20 +177,21 @@ typedef struct SColumnDataAgg {
struct {
uint64_t decimal128Sum[2];
uint64_t decimal128Max[2];
uint64_t decimal128Min[2]; // TODO wjm 1. use deicmal128Sum for decimal64, 2. add overflow flag
uint64_t decimal128Min[2]; // TODO wjm 1. use deicmal128Sum for decimal64, 2. add overflow flag
uint8_t overflow;
};
};
} SColumnDataAgg;
#pragma pack(pop)
#define COL_AGG_GET_SUM_PTR(pAggs, dataType) \
(dataType != TSDB_DATA_TYPE_DECIMAL ? (void*)&pAggs->sum : (void*)pAggs->decimal128Sum)
(!IS_DECIMAL_TYPE(dataType) ? (void*)&pAggs->sum : (void*)pAggs->decimal128Sum)
#define COL_AGG_GET_MAX_PTR(pAggs, dataType) \
(dataType != TSDB_DATA_TYPE_DECIMAL ? (void*)&pAggs->max : (void*)pAggs->decimal128Max)
(!IS_DECIMAL_TYPE(dataType) ? (void*)&pAggs->max : (void*)pAggs->decimal128Max)
#define COL_AGG_GET_MIN_PTR(pAggs, dataType) \
(dataType != TSDB_DATA_TYPE_DECIMAL ? (void*)&pAggs->min : (void*)pAggs->decimal128Min)
(!IS_DECIMAL_TYPE(dataType) ? (void*)&pAggs->min : (void*)pAggs->decimal128Min)
typedef struct SBlockID {
// The uid of table, from which current data block comes. And it is always 0, if current block is the

View File

@ -95,6 +95,7 @@ bool decimalCompare(EOperatorType op, const SDecimalCompareCtx* pLeft, const
int32_t decimalOp(EOperatorType op, const SDataType* pLeftT, const SDataType* pRightT, const SDataType* pOutT,
const void* pLeftData, const void* pRightData, void* pOutputData);
int32_t convertToDecimal(const void* pData, const SDataType* pInputType, void* pOut, const SDataType* pOutType);
bool decimal128AddCheckOverflow(const Decimal128* pLeft, const DecimalType* pRight, uint8_t rightWordNum);
DEFINE_TYPE_FROM_DECIMAL_FUNCS(, Decimal64);
DEFINE_TYPE_FROM_DECIMAL_FUNCS(, Decimal128);

View File

@ -4222,29 +4222,36 @@ static FORCE_INLINE void tColDataCalcSMAVarType(SColData *pColData, SColumnDataA
}
}
#define CALC_DECIMAL_SUM_MAX_MIN(TYPE, pOps, pColData, pSum, pMax, pMin) \
for (int32_t iVal = 0; iVal < pColData->nVal; ++iVal) { \
pVal = ((TYPE *)pColData->pData) + iVal; \
pOps->add(pSum, pVal, WORD_NUM(TYPE)); \
if (pOps->gt(pVal, pMax, WORD_NUM(TYPE))) { \
*(pMax) = *pVal; \
} \
if (pOps->lt(pVal, pMin, WORD_NUM(TYPE))) { \
*(pMin) = *pVal; \
} \
}
#define CALC_DECIMAL_SUM_MAX_MIN(TYPE, pSumOp, pCompOp, pColData, pSum, pMax, pMin) \
do { \
if (decimal128AddCheckOverflow((Decimal *)pSum, pVal, WORD_NUM(TYPE))) *pOverflow = true; \
pSumOp->add(pSum, pVal, WORD_NUM(TYPE)); \
if (pCompOp->gt(pVal, pMax, WORD_NUM(TYPE))) { \
*(pMax) = *pVal; \
} \
if (pCompOp->lt(pVal, pMin, WORD_NUM(TYPE))) { \
*(pMin) = *pVal; \
} \
} while (0)
static FORCE_INLINE void tColDataCalcSMADecimal64Type(SColData* pColData, SColumnDataAgg* pAggs) {
Decimal64* pSum = (Decimal64*)&pAggs->sum, *pMax = (Decimal64*)&pAggs->max, *pMin = (Decimal64*)&pAggs->min;
*pSum = DECIMAL64_ZERO;
Decimal128 *pSum = (Decimal128 *)pAggs->decimal128Sum;
Decimal64 *pMax = (Decimal64 *)pAggs->decimal128Max, *pMin = (Decimal64 *)pAggs->decimal128Min;
uint8_t *pOverflow = &pAggs->overflow;
*pSum = DECIMAL128_ZERO;
*pMax = DECIMAL64_MIN;
*pMin = DECIMAL64_MAX;
pAggs->numOfNull = 0;
pAggs->colId |= 0x80000000; // TODO wjm define it
Decimal64 *pVal = NULL;
SDecimalOps *pOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL64);
const SDecimalOps *pSumOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL);
const SDecimalOps *pCompOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL64);
if (HAS_VALUE == pColData->flag) {
CALC_DECIMAL_SUM_MAX_MIN(Decimal64, pOps, pColData, pSum, pMax, pMin);
for (int32_t iVal = 0; iVal < pColData->nVal; ++iVal) {
pVal = ((Decimal64*)pColData->pData) + iVal;
CALC_DECIMAL_SUM_MAX_MIN(Decimal64, pSumOps, pCompOps, pColData, pSum, pMax, pMin);
}
} else {
for (int32_t iVal = 0; iVal < pColData->nVal; ++iVal) {
switch (tColDataGetBitValue(pColData, iVal)) {
@ -4253,7 +4260,8 @@ static FORCE_INLINE void tColDataCalcSMADecimal64Type(SColData* pColData, SColum
pAggs->numOfNull++;
break;
case 2:
CALC_DECIMAL_SUM_MAX_MIN(Decimal64, pOps, pColData, pSum, pMax, pMin);// TODO wjm what if overflow
pVal = ((Decimal64 *)pColData->pData) + iVal;
CALC_DECIMAL_SUM_MAX_MIN(Decimal64, pSumOps, pCompOps, pColData, pSum, pMax, pMin);
break;
default:
break;
@ -4263,7 +4271,9 @@ static FORCE_INLINE void tColDataCalcSMADecimal64Type(SColData* pColData, SColum
}
static FORCE_INLINE void tColDataCalcSMADecimal128Type(SColData* pColData, SColumnDataAgg* pAggs) {
Decimal128* pSum = (Decimal128*)pAggs->decimal128Sum, *pMax = (Decimal128*)pAggs->decimal128Max, *pMin = (Decimal128*)pAggs->decimal128Min;
Decimal128 *pSum = (Decimal128 *)pAggs->decimal128Sum, *pMax = (Decimal128 *)pAggs->decimal128Max,
*pMin = (Decimal128 *)pAggs->decimal128Min;
uint8_t *pOverflow = &pAggs->overflow;
*pSum = DECIMAL128_ZERO;
*pMax = DECIMAL128_MIN;
*pMin = DECIMAL128_MAX;
@ -4273,7 +4283,10 @@ static FORCE_INLINE void tColDataCalcSMADecimal128Type(SColData* pColData, SColu
Decimal128 *pVal = NULL;
SDecimalOps* pOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL);
if (HAS_VALUE == pColData->flag) {
CALC_DECIMAL_SUM_MAX_MIN(Decimal128, pOps, pColData, pSum, pMax, pMin);
for (int32_t iVal = 0; iVal < pColData->nVal; ++iVal) {
pVal = ((Decimal128*)pColData->pData) + iVal;
CALC_DECIMAL_SUM_MAX_MIN(Decimal128, pOps, pOps, pColData, pSum, pMax, pMin);
}
} else {
for (int32_t iVal = 0; iVal < pColData->nVal; ++iVal) {
switch (tColDataGetBitValue(pColData, iVal)) {
@ -4282,7 +4295,8 @@ static FORCE_INLINE void tColDataCalcSMADecimal128Type(SColData* pColData, SColu
pAggs->numOfNull++;
break;
case 2:
CALC_DECIMAL_SUM_MAX_MIN(Decimal128, pOps, pColData, pSum, pMax, pMin);
pVal = ((Decimal128*)pColData->pData) + iVal;
CALC_DECIMAL_SUM_MAX_MIN(Decimal128, pOps, pOps, pColData, pSum, pMax, pMin);
break;
default:
break;

View File

@ -1577,12 +1577,13 @@ int32_t tPutColumnDataAgg(SBuffer *buffer, SColumnDataAgg *pColAgg) {
if (pColAgg->colId & 0x80000000) {
if ((code = tBufferPutI32v(buffer, pColAgg->colId))) return code;
if ((code = tBufferPutI16v(buffer, pColAgg->numOfNull))) return code;
if ((code = tBufferPutI64(buffer, pColAgg->decimal128Sum[0]))) return code;
if ((code = tBufferPutI64(buffer, pColAgg->decimal128Sum[1]))) return code;
if ((code = tBufferPutI64(buffer, pColAgg->decimal128Max[0]))) return code;
if ((code = tBufferPutI64(buffer, pColAgg->decimal128Max[1]))) return code;
if ((code = tBufferPutI64(buffer, pColAgg->decimal128Min[0]))) return code;
if ((code = tBufferPutI64(buffer, pColAgg->decimal128Min[1]))) return code;
if ((code = tBufferPutU64(buffer, pColAgg->decimal128Sum[0]))) return code;
if ((code = tBufferPutU64(buffer, pColAgg->decimal128Sum[1]))) return code;
if ((code = tBufferPutU64(buffer, pColAgg->decimal128Max[0]))) return code;
if ((code = tBufferPutU64(buffer, pColAgg->decimal128Max[1]))) return code;
if ((code = tBufferPutU64(buffer, pColAgg->decimal128Min[0]))) return code;
if ((code = tBufferPutU64(buffer, pColAgg->decimal128Min[1]))) return code;
if ((code = tBufferPutU8(buffer, pColAgg->overflow))) return code;
} else {
if ((code = tBufferPutI32v(buffer, pColAgg->colId))) return code;
if ((code = tBufferPutI16v(buffer, pColAgg->numOfNull))) return code;
@ -1601,12 +1602,13 @@ int32_t tGetColumnDataAgg(SBufferReader *br, SColumnDataAgg *pColAgg) {
if ((code = tBufferGetI16v(br, &pColAgg->numOfNull))) return code;
if (pColAgg->colId & 0x80000000) {
pColAgg->colId &= 0xFFFF;
if ((code = tBufferGetI64(br, &pColAgg->decimal128Sum[0]))) return code;
if ((code = tBufferGetI64(br, &pColAgg->decimal128Sum[1]))) return code;
if ((code = tBufferGetI64(br, &pColAgg->decimal128Max[0]))) return code;
if ((code = tBufferGetI64(br, &pColAgg->decimal128Max[1]))) return code;
if ((code = tBufferGetI64(br, &pColAgg->decimal128Min[0]))) return code;
if ((code = tBufferGetI64(br, &pColAgg->decimal128Min[1]))) return code;
if ((code = tBufferGetU64(br, &pColAgg->decimal128Sum[0]))) return code;
if ((code = tBufferGetU64(br, &pColAgg->decimal128Sum[1]))) return code;
if ((code = tBufferGetU64(br, &pColAgg->decimal128Max[0]))) return code;
if ((code = tBufferGetU64(br, &pColAgg->decimal128Max[1]))) return code;
if ((code = tBufferGetU64(br, &pColAgg->decimal128Min[0]))) return code;
if ((code = tBufferGetU64(br, &pColAgg->decimal128Min[1]))) return code;
if ((code = tBufferGetU8(br, &pColAgg->overflow))) return code;
} else {
if ((code = tBufferGetI64(br, &pColAgg->sum))) return code;
if ((code = tBufferGetI64(br, &pColAgg->max))) return code;

View File

@ -1920,6 +1920,18 @@ static int32_t decimal128CountRoundingDelta(const Decimal128* pDec, int8_t scale
return res;
}
bool decimal128AddCheckOverflow(const Decimal128* pLeft, const DecimalType* pRight, uint8_t rightWordNum) {
if (DECIMAL128_SIGN(pLeft) == 0) {
Decimal128 max = decimal128Max;
decimal128Subtract(&max, pLeft, WORD_NUM(Decimal128));
return decimal128Lt(&max, pRight, rightWordNum);
} else {
Decimal128 min = decimal128Min;
decimal128Subtract(&min, pLeft, WORD_NUM(Decimal128));
return decimal128Gt(&min, pRight, rightWordNum);
}
}
int32_t TEST_decimal64From_int64_t(Decimal64* pDec, uint8_t prec, uint8_t scale, int64_t v) {
return decimal64FromInt64(pDec, prec, scale, v);
}

View File

@ -1509,6 +1509,17 @@ TEST_F(DecimalTest, decimalFromStr) {
Numeric<128> numeric128 = {38, 10, "0"};
}
TEST(decimal, test_add_check_overflow) {
Numeric<128> dec128 = {38, 10, "9999999999999999999999999999.9999999999"};
Numeric<64> dec64 = {18, 2, "123.12"};
bool overflow = decimal128AddCheckOverflow((Decimal128*)&dec128.dec(), &dec64.dec(), WORD_NUM(Decimal64));
ASSERT_TRUE(overflow);
dec128 = {38, 10, "-9999999999999999999999999999.9999999999"};
ASSERT_FALSE(decimal128AddCheckOverflow((Decimal128*)&dec128.dec(), &dec64.dec(), WORD_NUM(Decimal64)));
dec64 = {18, 2, "-123.1"};
ASSERT_TRUE(decimal128AddCheckOverflow((Decimal128*)&dec128.dec(), &dec64.dec(), WORD_NUM(Decimal64)));
}
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);

View File

@ -87,13 +87,19 @@ typedef struct SDecimalSumRes {
#define SUM_RES_GET_DECIMAL_SUM(pSumRes) ((SDecimalSumRes*)(pSumRes))->sum
// TODO wjm check for overflow
#define SUM_RES_INC_DECIMAL_SUM(pSumRes, pVal, type) \
do { \
const SDecimalOps* pOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL); \
if (type == TSDB_DATA_TYPE_DECIMAL64) \
pOps->add(&SUM_RES_GET_DECIMAL_SUM(pSumRes), pVal, WORD_NUM(Decimal64)); \
else \
pOps->add(&SUM_RES_GET_DECIMAL_SUM(pSumRes), pVal, WORD_NUM(Decimal)); \
#define SUM_RES_INC_DECIMAL_SUM(pSumRes, pVal, type) \
do { \
const SDecimalOps* pOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL); \
int32_t wordNum = 0; \
if (type == TSDB_DATA_TYPE_DECIMAL64) { \
wordNum = WORD_NUM(Decimal64); \
overflow = decimal128AddCheckOverflow(&SUM_RES_GET_DECIMAL_SUM(pSumRes), pVal, wordNum); \
} else { \
wordNum = WORD_NUM(Decimal); \
overflow = decimal128AddCheckOverflow(&SUM_RES_GET_DECIMAL_SUM(pSumRes), pVal, wordNum); \
} \
if (overflow) break; \
pOps->add(&SUM_RES_GET_DECIMAL_SUM(pSumRes), pVal, wordNum); \
} while (0)
typedef struct SMinmaxResInfo {

View File

@ -106,17 +106,19 @@ typedef enum {
} \
} while (0)
#define LIST_ADD_DECIMAL_N(_res, _col, _start, _rows, _t, numOfElem) \
do { \
_t* d = (_t*)(_col->pData); \
const SDecimalOps* pOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL); \
for (int32_t i = (_start); i < (_rows) + (_start); ++i) { \
if (((_col)->hasNull) && colDataIsNull_f((_col)->nullbitmap, i)) { \
continue; \
}; \
pOps->add(_res, d + i, WORD_NUM(_t)); \
(numOfElem)++; \
} \
#define LIST_ADD_DECIMAL_N(_res, _col, _start, _rows, _t, numOfElem) \
do { \
_t* d = (_t*)(_col->pData); \
const SDecimalOps* pOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL); \
for (int32_t i = (_start); i < (_rows) + (_start); ++i) { \
if (((_col)->hasNull) && colDataIsNull_f((_col)->nullbitmap, i)) { \
continue; \
}; \
overflow = overflow || decimal128AddCheckOverflow((Decimal*)_res, d + i, WORD_NUM(_t)); \
if (overflow) break; \
pOps->add(_res, d + i, WORD_NUM(_t)); \
(numOfElem)++; \
} \
} while (0)
#define LIST_SUB_N(_res, _col, _start, _rows, _t, numOfElem) \
@ -652,12 +654,12 @@ int32_t sumFunction(SqlFunctionCtx* pCtx) {
SUM_RES_INC_DSUM(pSumRes, GET_DOUBLE_VAL((const char*)&(pAgg->sum)));
} else if (IS_DECIMAL_TYPE(type)) {
SUM_RES_SET_TYPE(pSumRes, pCtx->inputType, TSDB_DATA_TYPE_DECIMAL);
const SDecimalOps* pOps = getDecimalOps(type);
if (TSDB_DATA_TYPE_DECIMAL64 == type) {
pOps->add(&SUM_RES_GET_DECIMAL_SUM(pSumRes), &pAgg->sum, WORD_NUM(Decimal64));
} else if (TSDB_DATA_TYPE_DECIMAL == type) {
pOps->add(&SUM_RES_GET_DECIMAL_SUM(pSumRes), &pAgg->decimal128Sum, WORD_NUM(Decimal));
const SDecimalOps* pOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL);
if (pAgg->overflow || decimal128AddCheckOverflow((Decimal*)&SUM_RES_GET_DECIMAL_SUM(pSumRes),
&pAgg->decimal128Sum, WORD_NUM(Decimal))) {
return TSDB_CODE_DECIMAL_OVERFLOW;
}
pOps->add(&SUM_RES_GET_DECIMAL_SUM(pSumRes), &pAgg->decimal128Sum, WORD_NUM(Decimal));
}
} else { // computing based on the true data block
SColumnInfoData* pCol = pInput->pData[0];
@ -691,12 +693,13 @@ int32_t sumFunction(SqlFunctionCtx* pCtx) {
LIST_ADD_N(SUM_RES_GET_DSUM(pSumRes), pCol, start, numOfRows, float, numOfElem);
} else if (IS_DECIMAL_TYPE(type)) {
SUM_RES_SET_TYPE(pSumRes, pCtx->inputType, TSDB_DATA_TYPE_DECIMAL);
int32_t overflow = false;
if (TSDB_DATA_TYPE_DECIMAL64 == type) {
LIST_ADD_DECIMAL_N(&SUM_RES_GET_DECIMAL_SUM(pSumRes), pCol, start, numOfRows, Decimal64, numOfElem);
} else if (TSDB_DATA_TYPE_DECIMAL == type) {
LIST_ADD_DECIMAL_N(&SUM_RES_GET_DECIMAL_SUM(pSumRes), pCol, start, numOfRows, Decimal128, numOfElem);
}
// TODO wjm check overflow
if (overflow) return TSDB_CODE_DECIMAL_OVERFLOW;
}
}

View File

@ -126,7 +126,7 @@ int32_t avgFunctionSetup(SqlFunctionCtx* pCtx, SResultRowEntryInfo* pResultInfo)
return TSDB_CODE_SUCCESS;
}
static int32_t calculateAvgBySMAInfo(void* pRes, int32_t numOfRows, int32_t type, const SColumnDataAgg* pAgg) {
static int32_t calculateAvgBySMAInfo(void* pRes, int32_t numOfRows, int32_t type, const SColumnDataAgg* pAgg, int32_t* pNumOfElem) {
int32_t numOfElem = numOfRows - pAgg->numOfNull;
AVG_RES_INC_COUNT(pRes, type, numOfElem);
@ -137,16 +137,17 @@ static int32_t calculateAvgBySMAInfo(void* pRes, int32_t numOfRows, int32_t type
} else if (IS_FLOAT_TYPE(type)) {
SUM_RES_INC_DSUM(&AVG_RES_GET_SUM(pRes), GET_DOUBLE_VAL((const char*)&(pAgg->sum)));
} else if (IS_DECIMAL_TYPE(type)) {
if (type == TSDB_DATA_TYPE_DECIMAL64)
SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pRes), &pAgg->sum, TSDB_DATA_TYPE_DECIMAL64);
else
SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pRes), &pAgg->decimal128Sum, TSDB_DATA_TYPE_DECIMAL);
bool overflow = pAgg->overflow;
if (overflow) return TSDB_CODE_DECIMAL_OVERFLOW;
SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pRes), &pAgg->decimal128Sum, TSDB_DATA_TYPE_DECIMAL);
if (overflow) return TSDB_CODE_DECIMAL_OVERFLOW;
}
return numOfElem;
*pNumOfElem = numOfElem;
return 0;
}
static int32_t doAddNumericVector(SColumnInfoData* pCol, int32_t type, SInputColumnInfoData *pInput, void* pRes) {
static int32_t doAddNumericVector(SColumnInfoData* pCol, int32_t type, SInputColumnInfoData *pInput, void* pRes, int32_t* pNumOfElem) {
int32_t start = pInput->startRowIndex;
int32_t numOfRows = pInput->numOfRows;
int32_t numOfElems = 0;
@ -306,14 +307,17 @@ static int32_t doAddNumericVector(SColumnInfoData* pCol, int32_t type, SInputCol
numOfElems += 1;
AVG_RES_INC_COUNT(pRes, type, 1);
bool overflow = false;
SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pRes), (const void*)(pDec + i * tDataTypes[type].bytes), type);
if (overflow) return TSDB_CODE_DECIMAL_OVERFLOW;
}
} break;
default:
break;
}
return numOfElems;
*pNumOfElem = numOfElems;
return 0;
}
int32_t avgFunction(SqlFunctionCtx* pCtx) {
@ -341,7 +345,8 @@ int32_t avgFunction(SqlFunctionCtx* pCtx) {
if (IS_DECIMAL_TYPE(type)) AVG_RES_SET_INPUT_SCALE(pAvgRes, pInput->pData[0]->info.scale);
if (pInput->colDataSMAIsSet) { // try to use SMA if available
numOfElem = calculateAvgBySMAInfo(pAvgRes, numOfRows, type, pAgg);
int32_t code = calculateAvgBySMAInfo(pAvgRes, numOfRows, type, pAgg, &numOfElem);
if (code != 0) return code;
} else if (!pCol->hasNull) { // try to employ the simd instructions to speed up the loop
numOfElem = pInput->numOfRows;
AVG_RES_INC_COUNT(pAvgRes, pCtx->inputType, pInput->numOfRows);
@ -424,6 +429,7 @@ int32_t avgFunction(SqlFunctionCtx* pCtx) {
const char* pDec = pCol->pData;
// TODO wjm check for overflow
for (int32_t i = pInput->startRowIndex; i < pInput->numOfRows + pInput->startRowIndex; ++i) {
bool overflow = false;
if (type == TSDB_DATA_TYPE_DECIMAL64) {
SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pAvgRes), (const void*)(pDec + i * tDataTypes[type].bytes),
TSDB_DATA_TYPE_DECIMAL64);
@ -431,13 +437,14 @@ int32_t avgFunction(SqlFunctionCtx* pCtx) {
SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pAvgRes), (const void*)(pDec + i * tDataTypes[type].bytes),
TSDB_DATA_TYPE_DECIMAL);
}
if (overflow) return TSDB_CODE_DECIMAL_OVERFLOW;
}
} break;
default:
return TSDB_CODE_FUNC_FUNTION_PARA_TYPE;
}
} else {
numOfElem = doAddNumericVector(pCol, type, pInput, pAvgRes);
int32_t code = doAddNumericVector(pCol, type, pInput, pAvgRes, &numOfElem);
}
_over:
@ -446,12 +453,12 @@ _over:
return TSDB_CODE_SUCCESS;
}
static void avgTransferInfo(SqlFunctionCtx* pCtx, void* pInput, void* pOutput) {
static int32_t avgTransferInfo(SqlFunctionCtx* pCtx, void* pInput, void* pOutput) {
int32_t inputDT = pCtx->pExpr->pExpr->_function.pFunctNode->srcFuncInputType.type;
int32_t type = AVG_RES_GET_TYPE(pInput, inputDT);
pCtx->inputType = type;
if (IS_NULL_TYPE(type)) {
return;
return 0;
}
@ -464,12 +471,15 @@ static void avgTransferInfo(SqlFunctionCtx* pCtx, void* pInput, void* pOutput) {
CHECK_OVERFLOW_SUM_UNSIGNED_BIG(pOutput, (overflow ? SUM_RES_GET_DSUM(&AVG_RES_GET_SUM(pInput)) : SUM_RES_GET_USUM(&AVG_RES_GET_SUM(pInput))), overflow);
} else if (IS_DECIMAL_TYPE(type)) {
AVG_RES_SET_INPUT_SCALE(pOutput, AVG_RES_GET_INPUT_SCALE(pInput));
bool overflow = false;
SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pOutput), &AVG_RES_GET_DECIMAL_SUM(pInput), TSDB_DATA_TYPE_DECIMAL);
if (overflow) return TSDB_CODE_DECIMAL_OVERFLOW;
} else {
SUM_RES_INC_DSUM(&AVG_RES_GET_SUM(pOutput), SUM_RES_GET_DSUM(&AVG_RES_GET_SUM(pInput)));
}
AVG_RES_INC_COUNT(pOutput, type, AVG_RES_GET_COUNT(pInput, true, type));
return 0;
}
int32_t avgFunctionMerge(SqlFunctionCtx* pCtx) {
@ -493,7 +503,8 @@ int32_t avgFunctionMerge(SqlFunctionCtx* pCtx) {
if(colDataIsNull_s(pCol, i)) continue;
char* data = colDataGetData(pCol, i);
void* pInputInfo = varDataVal(data);
avgTransferInfo(pCtx, pInputInfo, pInfo);
int32_t code = avgTransferInfo(pCtx, pInputInfo, pInfo);
if (code != 0) return code;
}
SET_VAL(GET_RES_INFO(pCtx), 1, 1);

View File

@ -619,7 +619,7 @@ int32_t doMinMaxHelper(SqlFunctionCtx* pCtx, int32_t isMinFunc, int32_t* nElems)
int16_t index = 0;
void* tval = NULL;
if (type == TSDB_DATA_TYPE_DECIMAL) {
if (IS_DECIMAL_TYPE(type)) {
tval = isMinFunc ? pInput->pColumnDataAgg[0]->decimal128Min : pInput->pColumnDataAgg[0]->decimal128Max;
} else {
tval = (isMinFunc) ? &pInput->pColumnDataAgg[0]->min : &pInput->pColumnDataAgg[0]->max;

View File

@ -4228,7 +4228,7 @@ static int32_t fltSclBuildDecimalDatumFromValueNode(SFltSclDatum* datum, SColumn
if (datum->kind == FLT_SCL_DATUM_KIND_DECIMAL64 || datum->kind == FLT_SCL_DATUM_KIND_DECIMAL) {
int32_t code = convertToDecimal(pInput, &valDt, pData, &datum->type);
if (TSDB_CODE_SUCCESS != code) return code; // TODO wjm handle overflow error
valNode->node.resType = datum->type;
//valNode->node.resType = datum->type;
}
}
FLT_RET(0);

View File

@ -7,8 +7,10 @@ import time
import threading
import secrets
import numpy
from paramiko import Agent
import query
from tag_lite import datatype
from util.log import *
from util.sql import *
from util.cases import *
@ -43,12 +45,13 @@ scalar_convert_err = -2147470768
decimal_insert_validator_test = False
operator_test_round = 10
operator_test_round = 1
tb_insert_rows = 1000
binary_op_with_const_test = True
binary_op_with_col_test = True
unary_op_test = True
binary_op_in_where_test = True
test_decimal_funcs = True
class DecimalTypeGeneratorConfig:
def __init__(self):
@ -156,14 +159,14 @@ class DecimalColumnAggregator:
self.null_num: int = 0
self.none_num: int = 0
def add_value(self, value: str):
def add_value(self, value: str, scale: int):
self.count += 1
if value == "NULL":
self.null_num += 1
elif value == "None":
self.none_num += 1
else:
v: Decimal = Decimal(value)
v: Decimal = get_decimal(value, scale)
self.sum += v
if v > self.max:
self.max = v
@ -245,6 +248,10 @@ class DecimalColumnExpr:
def should_skip_for_decimal(self, cols: list):
return False
def check_query_results(self, query_col_res: List, tbname: str):
query_len = len(query_col_res)
pass
def check_for_filtering(self, query_col_res: List, tbname: str):
j: int = -1
for i in range(len(query_col_res)):
@ -479,6 +486,7 @@ class DataType:
return val
class DecimalType(DataType):
MAX_PRECISION = 38
def __init__(self, type, precision: int, scale: int):
self.precision_ = precision
self.scale_ = scale
@ -496,6 +504,14 @@ class DecimalType(DataType):
def get_decimal_type_mod(self) -> int:
return self.precision_ * 100 + self.scale()
def set_prec(self, prec: int):
self.precision_ = prec
self.type_mod = self.get_decimal_type_mod()
def set_scale(self, scale: int):
self.scale_ = scale
self.type_mod = self.get_decimal_type_mod()
def prec(self):
return self.precision_
@ -520,7 +536,7 @@ class DecimalType(DataType):
def generate_value(self) -> str:
val = self.decimal_generator.generate(self.generator_config)
self.aggregator.add_value(val) ## convert to Decimal first
self.aggregator.add_value(val, self.scale()) ## convert to Decimal first
# self.values.append(val) ## save it into files maybe
return val
@ -605,6 +621,12 @@ class Column:
if self.is_constant_col():
return self.get_constant_val_for_execute()
return self.get_typed_val_for_execute(self.saved_vals[tbname][idx])
def get_cardinality(self, tbname):
if self.is_constant_col():
return 1
else:
return len(self.saved_vals[tbname])
## tbName: for normal table, pass the tbname, for child table, pass the child table name
def generate_value(self, tbName: str = '', save: bool = True):
@ -824,6 +846,143 @@ class TableDataValidator:
col.check(res[colIdx], row_num * self.tbIdx)
colIdx += 1
class DecimalFunction(DecimalColumnExpr):
def __init__(self, format, executor, name: str):
super().__init__(format, executor)
self.func_name_ = name
def is_agg_func(self, op: str):
return False
def get_func_res(self):
return None
def get_input_types(self):
return [self.query_col]
@staticmethod
def get_decimal_agg_funcs() -> List:
return [DecimalMinFunction(), DecimalMaxFunction(), DecimalSumFunction(), DecimalAvgFunction()]
def check_results(self, query_col_res: List) -> bool:
return False
def check_for_agg_func(self, query_col_res: List, tbname: str):
col_expr = self.query_col
for i in range(col_expr.get_cardinality(tbname)):
col_val = col_expr.get_val_for_execute(tbname, i)
self.execute((col_val,))
if not self.check_results(query_col_res):
tdLog.exit(f"check failed for {self}, query got: {query_col_res}, expect {self.get_func_res()}")
class DecimalAggFunction(DecimalFunction):
def __init__(self, format, executor, name: str):
super().__init__(format, executor, name)
def is_agg_func(self, op):
return True
def should_skip_for_decimal(self, cols: list):
col: Column = cols[0]
if col.type_.is_decimal_type():
return False
return True
def check_results(self, query_col_res):
if len(query_col_res) == 0:
tdLog.info(f"query got no output: {self}, py calc: {self.get_func_res()}")
return True
else:
return self.get_func_res() == Decimal(query_col_res[0])
class DecimalMinFunction(DecimalAggFunction):
def __init__(self):
super().__init__("min({0})", DecimalMinFunction.execute_min, "min")
self.min_: Decimal = None
def get_func_res(self) -> Decimal:
decimal_type: DecimalType = self.query_col.type_
return decimal_type.aggregator.min
return self.min_
def generate_res_type(self) -> DataType:
self.res_type_ = self.query_col.type_
def execute_min(self, params):
if params[0] is None:
return
if self.min_ is None:
self.min_ = Decimal(params[0])
else:
self.min_ = min(self.min_, Decimal(params[0]))
return self.min_
class DecimalMaxFunction(DecimalAggFunction):
def __init__(self):
super().__init__("max({0})", DecimalMaxFunction.execute_max, "max")
self.max_: Decimal = None
def get_func_res(self) -> Decimal:
return self.max_
def generate_res_type(self) -> DataType:
self.res_type_ = self.query_col.type_
def execute_max(self, params):
if params[0] is None:
return
if self.max_ is None:
self.max_ = Decimal(params[0])
else:
self.max_ = max(self.max_, Decimal(params[0]))
return self.max_
class DecimalSumFunction(DecimalAggFunction):
def __init__(self):
super().__init__("sum({0})", DecimalSumFunction.execute_sum, "sum")
self.sum_:Decimal = None
def get_func_res(self) -> Decimal:
decimal_type: DecimalType = self.query_col.type_
return decimal_type.aggregator.sum
return self.sum_
def generate_res_type(self) -> DataType:
self.res_type_ = self.query_col.type_
self.res_type_.set_prec(DecimalType.MAX_PRECISION)
def execute_sum(self, params):
if params[0] is None:
return
if self.sum_ is None:
self.sum_ = Decimal(params[0])
else:
self.sum_ += Decimal(params[0])
return self.sum_
class DecimalAvgFunction(DecimalAggFunction):
def __init__(self):
super().__init__("avg({0})", DecimalAvgFunction.execute_avg, "avg")
self.count_: Decimal = 0
self.sum_: Decimal = None
def get_func_res(self) -> Decimal:
decimal_type: DecimalType = self.query_col.type_
return get_decimal(
decimal_type.aggregator.sum
/ (decimal_type.aggregator.count - decimal_type.aggregator.null_num),
self.res_type_.scale(),
)
def generate_res_type(self) -> DataType:
sum_type = self.query_col.type_
sum_type.set_prec(DecimalType.MAX_PRECISION)
count_type = DataType(TypeEnum.BIGINT, 8, 0)
self.res_type_ = DecimalBinaryOperator.calc_decimal_prec_scale(sum_type, count_type, "/")
def execute_avg(self, params):
if params[0] is None:
return
if self.sum_ is None:
self.sum_ = Decimal(params[0])
else:
self.sum_ += Decimal(params[0])
self.count_ += 1
return self.get_func_res()
class DecimalBinaryOperator(DecimalColumnExpr):
def __init__(self, format, executor, op: str):
@ -845,7 +1004,7 @@ class DecimalBinaryOperator(DecimalColumnExpr):
return True
if self.op_ != "%":
return False
## why skip decimal % float/double? it's wrong now.
## TODO wjm why skip decimal % float/double? it's wrong now.
left_is_real = left_col.type_.is_real_type() or left_col.type_.is_varchar_type()
right_is_real = right_col.type_.is_real_type() or right_col.type_.is_varchar_type()
if left_is_real or right_is_real:
@ -1554,7 +1713,7 @@ class TDTestCase:
threads.append(t)
for t in threads:
t.join()
def run_in_thread2(self, func, params) -> threading.Thread:
t = threading.Thread(target=func, args=params)
t.start()
@ -1638,10 +1797,6 @@ class TDTestCase:
self.test_query_decimal_where_clause()
def test_decimal_functions(self):
self.test_decimal_last_first_func()
funcs = ["max", "min", "sum", "avg", "count", "first", "last", "cast"]
def test_decimal_last_first_func(self):
pass
@ -1663,9 +1818,7 @@ class TDTestCase:
for col in tb_cols:
if col.name_ == '':
continue
left_is_decimal = col.type_.is_decimal_type()
for const_col in constant_cols:
right_is_decimal = const_col.type_.is_decimal_type()
if expr.should_skip_for_decimal([col, const_col]):
continue
const_col.generate_value()
@ -1740,7 +1893,7 @@ class TDTestCase:
self.norm_table_name,
self.norm_tb_columns,
binary_compare_ops)
## TODO wjm
## 3. (dec op const col) op const col
## 4. (dec op dec) op const col
@ -1765,6 +1918,32 @@ class TDTestCase:
def test_query_decimal_case_when(self):
pass
def test_decimal_agg_funcs(self, dbname, tbname, tb_cols: List[Column], get_agg_funcs_func):
agg_funcs: List[DecimalFunction] = get_agg_funcs_func()
for func in agg_funcs:
for col in tb_cols:
if col.name_ == '' or func.should_skip_for_decimal([col]):
continue
func.query_col = col
select_expr = func.generate([col])
sql = f"select {select_expr} from {dbname}.{tbname}"
res = TaosShell().query(sql)
if len(res) > 0:
res = res[0]
func.check_for_agg_func(res, tbname)
def test_decimal_functions(self):
if not test_decimal_funcs:
return
self.test_decimal_agg_funcs(
self.db_name,
self.norm_table_name,
self.norm_tb_columns,
DecimalFunction.get_decimal_agg_funcs,
)
self.test_decimal_last_first_func()
def test_query_decimal(self):
self.test_decimal_operators()
self.test_decimal_functions()