test decimal agg funcs
This commit is contained in:
parent
7a1ffd92ac
commit
caccf4add9
|
@ -177,20 +177,21 @@ typedef struct SColumnDataAgg {
|
|||
struct {
|
||||
uint64_t decimal128Sum[2];
|
||||
uint64_t decimal128Max[2];
|
||||
uint64_t decimal128Min[2]; // TODO wjm 1. use deicmal128Sum for decimal64, 2. add overflow flag
|
||||
uint64_t decimal128Min[2]; // TODO wjm 1. use deicmal128Sum for decimal64, 2. add overflow flag
|
||||
uint8_t overflow;
|
||||
};
|
||||
};
|
||||
} SColumnDataAgg;
|
||||
#pragma pack(pop)
|
||||
|
||||
#define COL_AGG_GET_SUM_PTR(pAggs, dataType) \
|
||||
(dataType != TSDB_DATA_TYPE_DECIMAL ? (void*)&pAggs->sum : (void*)pAggs->decimal128Sum)
|
||||
(!IS_DECIMAL_TYPE(dataType) ? (void*)&pAggs->sum : (void*)pAggs->decimal128Sum)
|
||||
|
||||
#define COL_AGG_GET_MAX_PTR(pAggs, dataType) \
|
||||
(dataType != TSDB_DATA_TYPE_DECIMAL ? (void*)&pAggs->max : (void*)pAggs->decimal128Max)
|
||||
(!IS_DECIMAL_TYPE(dataType) ? (void*)&pAggs->max : (void*)pAggs->decimal128Max)
|
||||
|
||||
#define COL_AGG_GET_MIN_PTR(pAggs, dataType) \
|
||||
(dataType != TSDB_DATA_TYPE_DECIMAL ? (void*)&pAggs->min : (void*)pAggs->decimal128Min)
|
||||
(!IS_DECIMAL_TYPE(dataType) ? (void*)&pAggs->min : (void*)pAggs->decimal128Min)
|
||||
|
||||
typedef struct SBlockID {
|
||||
// The uid of table, from which current data block comes. And it is always 0, if current block is the
|
||||
|
|
|
@ -95,6 +95,7 @@ bool decimalCompare(EOperatorType op, const SDecimalCompareCtx* pLeft, const
|
|||
int32_t decimalOp(EOperatorType op, const SDataType* pLeftT, const SDataType* pRightT, const SDataType* pOutT,
|
||||
const void* pLeftData, const void* pRightData, void* pOutputData);
|
||||
int32_t convertToDecimal(const void* pData, const SDataType* pInputType, void* pOut, const SDataType* pOutType);
|
||||
bool decimal128AddCheckOverflow(const Decimal128* pLeft, const DecimalType* pRight, uint8_t rightWordNum);
|
||||
|
||||
DEFINE_TYPE_FROM_DECIMAL_FUNCS(, Decimal64);
|
||||
DEFINE_TYPE_FROM_DECIMAL_FUNCS(, Decimal128);
|
||||
|
|
|
@ -4222,29 +4222,36 @@ static FORCE_INLINE void tColDataCalcSMAVarType(SColData *pColData, SColumnDataA
|
|||
}
|
||||
}
|
||||
|
||||
#define CALC_DECIMAL_SUM_MAX_MIN(TYPE, pOps, pColData, pSum, pMax, pMin) \
|
||||
for (int32_t iVal = 0; iVal < pColData->nVal; ++iVal) { \
|
||||
pVal = ((TYPE *)pColData->pData) + iVal; \
|
||||
pOps->add(pSum, pVal, WORD_NUM(TYPE)); \
|
||||
if (pOps->gt(pVal, pMax, WORD_NUM(TYPE))) { \
|
||||
*(pMax) = *pVal; \
|
||||
} \
|
||||
if (pOps->lt(pVal, pMin, WORD_NUM(TYPE))) { \
|
||||
*(pMin) = *pVal; \
|
||||
} \
|
||||
}
|
||||
#define CALC_DECIMAL_SUM_MAX_MIN(TYPE, pSumOp, pCompOp, pColData, pSum, pMax, pMin) \
|
||||
do { \
|
||||
if (decimal128AddCheckOverflow((Decimal *)pSum, pVal, WORD_NUM(TYPE))) *pOverflow = true; \
|
||||
pSumOp->add(pSum, pVal, WORD_NUM(TYPE)); \
|
||||
if (pCompOp->gt(pVal, pMax, WORD_NUM(TYPE))) { \
|
||||
*(pMax) = *pVal; \
|
||||
} \
|
||||
if (pCompOp->lt(pVal, pMin, WORD_NUM(TYPE))) { \
|
||||
*(pMin) = *pVal; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
static FORCE_INLINE void tColDataCalcSMADecimal64Type(SColData* pColData, SColumnDataAgg* pAggs) {
|
||||
Decimal64* pSum = (Decimal64*)&pAggs->sum, *pMax = (Decimal64*)&pAggs->max, *pMin = (Decimal64*)&pAggs->min;
|
||||
*pSum = DECIMAL64_ZERO;
|
||||
Decimal128 *pSum = (Decimal128 *)pAggs->decimal128Sum;
|
||||
Decimal64 *pMax = (Decimal64 *)pAggs->decimal128Max, *pMin = (Decimal64 *)pAggs->decimal128Min;
|
||||
uint8_t *pOverflow = &pAggs->overflow;
|
||||
*pSum = DECIMAL128_ZERO;
|
||||
*pMax = DECIMAL64_MIN;
|
||||
*pMin = DECIMAL64_MAX;
|
||||
pAggs->numOfNull = 0;
|
||||
pAggs->colId |= 0x80000000; // TODO wjm define it
|
||||
|
||||
Decimal64 *pVal = NULL;
|
||||
SDecimalOps *pOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL64);
|
||||
const SDecimalOps *pSumOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL);
|
||||
const SDecimalOps *pCompOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL64);
|
||||
if (HAS_VALUE == pColData->flag) {
|
||||
CALC_DECIMAL_SUM_MAX_MIN(Decimal64, pOps, pColData, pSum, pMax, pMin);
|
||||
for (int32_t iVal = 0; iVal < pColData->nVal; ++iVal) {
|
||||
pVal = ((Decimal64*)pColData->pData) + iVal;
|
||||
CALC_DECIMAL_SUM_MAX_MIN(Decimal64, pSumOps, pCompOps, pColData, pSum, pMax, pMin);
|
||||
}
|
||||
} else {
|
||||
for (int32_t iVal = 0; iVal < pColData->nVal; ++iVal) {
|
||||
switch (tColDataGetBitValue(pColData, iVal)) {
|
||||
|
@ -4253,7 +4260,8 @@ static FORCE_INLINE void tColDataCalcSMADecimal64Type(SColData* pColData, SColum
|
|||
pAggs->numOfNull++;
|
||||
break;
|
||||
case 2:
|
||||
CALC_DECIMAL_SUM_MAX_MIN(Decimal64, pOps, pColData, pSum, pMax, pMin);// TODO wjm what if overflow
|
||||
pVal = ((Decimal64 *)pColData->pData) + iVal;
|
||||
CALC_DECIMAL_SUM_MAX_MIN(Decimal64, pSumOps, pCompOps, pColData, pSum, pMax, pMin);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
|
@ -4263,7 +4271,9 @@ static FORCE_INLINE void tColDataCalcSMADecimal64Type(SColData* pColData, SColum
|
|||
}
|
||||
|
||||
static FORCE_INLINE void tColDataCalcSMADecimal128Type(SColData* pColData, SColumnDataAgg* pAggs) {
|
||||
Decimal128* pSum = (Decimal128*)pAggs->decimal128Sum, *pMax = (Decimal128*)pAggs->decimal128Max, *pMin = (Decimal128*)pAggs->decimal128Min;
|
||||
Decimal128 *pSum = (Decimal128 *)pAggs->decimal128Sum, *pMax = (Decimal128 *)pAggs->decimal128Max,
|
||||
*pMin = (Decimal128 *)pAggs->decimal128Min;
|
||||
uint8_t *pOverflow = &pAggs->overflow;
|
||||
*pSum = DECIMAL128_ZERO;
|
||||
*pMax = DECIMAL128_MIN;
|
||||
*pMin = DECIMAL128_MAX;
|
||||
|
@ -4273,7 +4283,10 @@ static FORCE_INLINE void tColDataCalcSMADecimal128Type(SColData* pColData, SColu
|
|||
Decimal128 *pVal = NULL;
|
||||
SDecimalOps* pOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL);
|
||||
if (HAS_VALUE == pColData->flag) {
|
||||
CALC_DECIMAL_SUM_MAX_MIN(Decimal128, pOps, pColData, pSum, pMax, pMin);
|
||||
for (int32_t iVal = 0; iVal < pColData->nVal; ++iVal) {
|
||||
pVal = ((Decimal128*)pColData->pData) + iVal;
|
||||
CALC_DECIMAL_SUM_MAX_MIN(Decimal128, pOps, pOps, pColData, pSum, pMax, pMin);
|
||||
}
|
||||
} else {
|
||||
for (int32_t iVal = 0; iVal < pColData->nVal; ++iVal) {
|
||||
switch (tColDataGetBitValue(pColData, iVal)) {
|
||||
|
@ -4282,7 +4295,8 @@ static FORCE_INLINE void tColDataCalcSMADecimal128Type(SColData* pColData, SColu
|
|||
pAggs->numOfNull++;
|
||||
break;
|
||||
case 2:
|
||||
CALC_DECIMAL_SUM_MAX_MIN(Decimal128, pOps, pColData, pSum, pMax, pMin);
|
||||
pVal = ((Decimal128*)pColData->pData) + iVal;
|
||||
CALC_DECIMAL_SUM_MAX_MIN(Decimal128, pOps, pOps, pColData, pSum, pMax, pMin);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
|
|
|
@ -1577,12 +1577,13 @@ int32_t tPutColumnDataAgg(SBuffer *buffer, SColumnDataAgg *pColAgg) {
|
|||
if (pColAgg->colId & 0x80000000) {
|
||||
if ((code = tBufferPutI32v(buffer, pColAgg->colId))) return code;
|
||||
if ((code = tBufferPutI16v(buffer, pColAgg->numOfNull))) return code;
|
||||
if ((code = tBufferPutI64(buffer, pColAgg->decimal128Sum[0]))) return code;
|
||||
if ((code = tBufferPutI64(buffer, pColAgg->decimal128Sum[1]))) return code;
|
||||
if ((code = tBufferPutI64(buffer, pColAgg->decimal128Max[0]))) return code;
|
||||
if ((code = tBufferPutI64(buffer, pColAgg->decimal128Max[1]))) return code;
|
||||
if ((code = tBufferPutI64(buffer, pColAgg->decimal128Min[0]))) return code;
|
||||
if ((code = tBufferPutI64(buffer, pColAgg->decimal128Min[1]))) return code;
|
||||
if ((code = tBufferPutU64(buffer, pColAgg->decimal128Sum[0]))) return code;
|
||||
if ((code = tBufferPutU64(buffer, pColAgg->decimal128Sum[1]))) return code;
|
||||
if ((code = tBufferPutU64(buffer, pColAgg->decimal128Max[0]))) return code;
|
||||
if ((code = tBufferPutU64(buffer, pColAgg->decimal128Max[1]))) return code;
|
||||
if ((code = tBufferPutU64(buffer, pColAgg->decimal128Min[0]))) return code;
|
||||
if ((code = tBufferPutU64(buffer, pColAgg->decimal128Min[1]))) return code;
|
||||
if ((code = tBufferPutU8(buffer, pColAgg->overflow))) return code;
|
||||
} else {
|
||||
if ((code = tBufferPutI32v(buffer, pColAgg->colId))) return code;
|
||||
if ((code = tBufferPutI16v(buffer, pColAgg->numOfNull))) return code;
|
||||
|
@ -1601,12 +1602,13 @@ int32_t tGetColumnDataAgg(SBufferReader *br, SColumnDataAgg *pColAgg) {
|
|||
if ((code = tBufferGetI16v(br, &pColAgg->numOfNull))) return code;
|
||||
if (pColAgg->colId & 0x80000000) {
|
||||
pColAgg->colId &= 0xFFFF;
|
||||
if ((code = tBufferGetI64(br, &pColAgg->decimal128Sum[0]))) return code;
|
||||
if ((code = tBufferGetI64(br, &pColAgg->decimal128Sum[1]))) return code;
|
||||
if ((code = tBufferGetI64(br, &pColAgg->decimal128Max[0]))) return code;
|
||||
if ((code = tBufferGetI64(br, &pColAgg->decimal128Max[1]))) return code;
|
||||
if ((code = tBufferGetI64(br, &pColAgg->decimal128Min[0]))) return code;
|
||||
if ((code = tBufferGetI64(br, &pColAgg->decimal128Min[1]))) return code;
|
||||
if ((code = tBufferGetU64(br, &pColAgg->decimal128Sum[0]))) return code;
|
||||
if ((code = tBufferGetU64(br, &pColAgg->decimal128Sum[1]))) return code;
|
||||
if ((code = tBufferGetU64(br, &pColAgg->decimal128Max[0]))) return code;
|
||||
if ((code = tBufferGetU64(br, &pColAgg->decimal128Max[1]))) return code;
|
||||
if ((code = tBufferGetU64(br, &pColAgg->decimal128Min[0]))) return code;
|
||||
if ((code = tBufferGetU64(br, &pColAgg->decimal128Min[1]))) return code;
|
||||
if ((code = tBufferGetU8(br, &pColAgg->overflow))) return code;
|
||||
} else {
|
||||
if ((code = tBufferGetI64(br, &pColAgg->sum))) return code;
|
||||
if ((code = tBufferGetI64(br, &pColAgg->max))) return code;
|
||||
|
|
|
@ -1920,6 +1920,18 @@ static int32_t decimal128CountRoundingDelta(const Decimal128* pDec, int8_t scale
|
|||
return res;
|
||||
}
|
||||
|
||||
bool decimal128AddCheckOverflow(const Decimal128* pLeft, const DecimalType* pRight, uint8_t rightWordNum) {
|
||||
if (DECIMAL128_SIGN(pLeft) == 0) {
|
||||
Decimal128 max = decimal128Max;
|
||||
decimal128Subtract(&max, pLeft, WORD_NUM(Decimal128));
|
||||
return decimal128Lt(&max, pRight, rightWordNum);
|
||||
} else {
|
||||
Decimal128 min = decimal128Min;
|
||||
decimal128Subtract(&min, pLeft, WORD_NUM(Decimal128));
|
||||
return decimal128Gt(&min, pRight, rightWordNum);
|
||||
}
|
||||
}
|
||||
|
||||
int32_t TEST_decimal64From_int64_t(Decimal64* pDec, uint8_t prec, uint8_t scale, int64_t v) {
|
||||
return decimal64FromInt64(pDec, prec, scale, v);
|
||||
}
|
||||
|
|
|
@ -1509,6 +1509,17 @@ TEST_F(DecimalTest, decimalFromStr) {
|
|||
Numeric<128> numeric128 = {38, 10, "0"};
|
||||
}
|
||||
|
||||
TEST(decimal, test_add_check_overflow) {
|
||||
Numeric<128> dec128 = {38, 10, "9999999999999999999999999999.9999999999"};
|
||||
Numeric<64> dec64 = {18, 2, "123.12"};
|
||||
bool overflow = decimal128AddCheckOverflow((Decimal128*)&dec128.dec(), &dec64.dec(), WORD_NUM(Decimal64));
|
||||
ASSERT_TRUE(overflow);
|
||||
dec128 = {38, 10, "-9999999999999999999999999999.9999999999"};
|
||||
ASSERT_FALSE(decimal128AddCheckOverflow((Decimal128*)&dec128.dec(), &dec64.dec(), WORD_NUM(Decimal64)));
|
||||
dec64 = {18, 2, "-123.1"};
|
||||
ASSERT_TRUE(decimal128AddCheckOverflow((Decimal128*)&dec128.dec(), &dec64.dec(), WORD_NUM(Decimal64)));
|
||||
}
|
||||
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
testing::InitGoogleTest(&argc, argv);
|
||||
|
|
|
@ -87,13 +87,19 @@ typedef struct SDecimalSumRes {
|
|||
|
||||
#define SUM_RES_GET_DECIMAL_SUM(pSumRes) ((SDecimalSumRes*)(pSumRes))->sum
|
||||
// TODO wjm check for overflow
|
||||
#define SUM_RES_INC_DECIMAL_SUM(pSumRes, pVal, type) \
|
||||
do { \
|
||||
const SDecimalOps* pOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL); \
|
||||
if (type == TSDB_DATA_TYPE_DECIMAL64) \
|
||||
pOps->add(&SUM_RES_GET_DECIMAL_SUM(pSumRes), pVal, WORD_NUM(Decimal64)); \
|
||||
else \
|
||||
pOps->add(&SUM_RES_GET_DECIMAL_SUM(pSumRes), pVal, WORD_NUM(Decimal)); \
|
||||
#define SUM_RES_INC_DECIMAL_SUM(pSumRes, pVal, type) \
|
||||
do { \
|
||||
const SDecimalOps* pOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL); \
|
||||
int32_t wordNum = 0; \
|
||||
if (type == TSDB_DATA_TYPE_DECIMAL64) { \
|
||||
wordNum = WORD_NUM(Decimal64); \
|
||||
overflow = decimal128AddCheckOverflow(&SUM_RES_GET_DECIMAL_SUM(pSumRes), pVal, wordNum); \
|
||||
} else { \
|
||||
wordNum = WORD_NUM(Decimal); \
|
||||
overflow = decimal128AddCheckOverflow(&SUM_RES_GET_DECIMAL_SUM(pSumRes), pVal, wordNum); \
|
||||
} \
|
||||
if (overflow) break; \
|
||||
pOps->add(&SUM_RES_GET_DECIMAL_SUM(pSumRes), pVal, wordNum); \
|
||||
} while (0)
|
||||
|
||||
typedef struct SMinmaxResInfo {
|
||||
|
|
|
@ -106,17 +106,19 @@ typedef enum {
|
|||
} \
|
||||
} while (0)
|
||||
|
||||
#define LIST_ADD_DECIMAL_N(_res, _col, _start, _rows, _t, numOfElem) \
|
||||
do { \
|
||||
_t* d = (_t*)(_col->pData); \
|
||||
const SDecimalOps* pOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL); \
|
||||
for (int32_t i = (_start); i < (_rows) + (_start); ++i) { \
|
||||
if (((_col)->hasNull) && colDataIsNull_f((_col)->nullbitmap, i)) { \
|
||||
continue; \
|
||||
}; \
|
||||
pOps->add(_res, d + i, WORD_NUM(_t)); \
|
||||
(numOfElem)++; \
|
||||
} \
|
||||
#define LIST_ADD_DECIMAL_N(_res, _col, _start, _rows, _t, numOfElem) \
|
||||
do { \
|
||||
_t* d = (_t*)(_col->pData); \
|
||||
const SDecimalOps* pOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL); \
|
||||
for (int32_t i = (_start); i < (_rows) + (_start); ++i) { \
|
||||
if (((_col)->hasNull) && colDataIsNull_f((_col)->nullbitmap, i)) { \
|
||||
continue; \
|
||||
}; \
|
||||
overflow = overflow || decimal128AddCheckOverflow((Decimal*)_res, d + i, WORD_NUM(_t)); \
|
||||
if (overflow) break; \
|
||||
pOps->add(_res, d + i, WORD_NUM(_t)); \
|
||||
(numOfElem)++; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define LIST_SUB_N(_res, _col, _start, _rows, _t, numOfElem) \
|
||||
|
@ -652,12 +654,12 @@ int32_t sumFunction(SqlFunctionCtx* pCtx) {
|
|||
SUM_RES_INC_DSUM(pSumRes, GET_DOUBLE_VAL((const char*)&(pAgg->sum)));
|
||||
} else if (IS_DECIMAL_TYPE(type)) {
|
||||
SUM_RES_SET_TYPE(pSumRes, pCtx->inputType, TSDB_DATA_TYPE_DECIMAL);
|
||||
const SDecimalOps* pOps = getDecimalOps(type);
|
||||
if (TSDB_DATA_TYPE_DECIMAL64 == type) {
|
||||
pOps->add(&SUM_RES_GET_DECIMAL_SUM(pSumRes), &pAgg->sum, WORD_NUM(Decimal64));
|
||||
} else if (TSDB_DATA_TYPE_DECIMAL == type) {
|
||||
pOps->add(&SUM_RES_GET_DECIMAL_SUM(pSumRes), &pAgg->decimal128Sum, WORD_NUM(Decimal));
|
||||
const SDecimalOps* pOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL);
|
||||
if (pAgg->overflow || decimal128AddCheckOverflow((Decimal*)&SUM_RES_GET_DECIMAL_SUM(pSumRes),
|
||||
&pAgg->decimal128Sum, WORD_NUM(Decimal))) {
|
||||
return TSDB_CODE_DECIMAL_OVERFLOW;
|
||||
}
|
||||
pOps->add(&SUM_RES_GET_DECIMAL_SUM(pSumRes), &pAgg->decimal128Sum, WORD_NUM(Decimal));
|
||||
}
|
||||
} else { // computing based on the true data block
|
||||
SColumnInfoData* pCol = pInput->pData[0];
|
||||
|
@ -691,12 +693,13 @@ int32_t sumFunction(SqlFunctionCtx* pCtx) {
|
|||
LIST_ADD_N(SUM_RES_GET_DSUM(pSumRes), pCol, start, numOfRows, float, numOfElem);
|
||||
} else if (IS_DECIMAL_TYPE(type)) {
|
||||
SUM_RES_SET_TYPE(pSumRes, pCtx->inputType, TSDB_DATA_TYPE_DECIMAL);
|
||||
int32_t overflow = false;
|
||||
if (TSDB_DATA_TYPE_DECIMAL64 == type) {
|
||||
LIST_ADD_DECIMAL_N(&SUM_RES_GET_DECIMAL_SUM(pSumRes), pCol, start, numOfRows, Decimal64, numOfElem);
|
||||
} else if (TSDB_DATA_TYPE_DECIMAL == type) {
|
||||
LIST_ADD_DECIMAL_N(&SUM_RES_GET_DECIMAL_SUM(pSumRes), pCol, start, numOfRows, Decimal128, numOfElem);
|
||||
}
|
||||
// TODO wjm check overflow
|
||||
if (overflow) return TSDB_CODE_DECIMAL_OVERFLOW;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -126,7 +126,7 @@ int32_t avgFunctionSetup(SqlFunctionCtx* pCtx, SResultRowEntryInfo* pResultInfo)
|
|||
return TSDB_CODE_SUCCESS;
|
||||
}
|
||||
|
||||
static int32_t calculateAvgBySMAInfo(void* pRes, int32_t numOfRows, int32_t type, const SColumnDataAgg* pAgg) {
|
||||
static int32_t calculateAvgBySMAInfo(void* pRes, int32_t numOfRows, int32_t type, const SColumnDataAgg* pAgg, int32_t* pNumOfElem) {
|
||||
int32_t numOfElem = numOfRows - pAgg->numOfNull;
|
||||
|
||||
AVG_RES_INC_COUNT(pRes, type, numOfElem);
|
||||
|
@ -137,16 +137,17 @@ static int32_t calculateAvgBySMAInfo(void* pRes, int32_t numOfRows, int32_t type
|
|||
} else if (IS_FLOAT_TYPE(type)) {
|
||||
SUM_RES_INC_DSUM(&AVG_RES_GET_SUM(pRes), GET_DOUBLE_VAL((const char*)&(pAgg->sum)));
|
||||
} else if (IS_DECIMAL_TYPE(type)) {
|
||||
if (type == TSDB_DATA_TYPE_DECIMAL64)
|
||||
SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pRes), &pAgg->sum, TSDB_DATA_TYPE_DECIMAL64);
|
||||
else
|
||||
SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pRes), &pAgg->decimal128Sum, TSDB_DATA_TYPE_DECIMAL);
|
||||
bool overflow = pAgg->overflow;
|
||||
if (overflow) return TSDB_CODE_DECIMAL_OVERFLOW;
|
||||
SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pRes), &pAgg->decimal128Sum, TSDB_DATA_TYPE_DECIMAL);
|
||||
if (overflow) return TSDB_CODE_DECIMAL_OVERFLOW;
|
||||
}
|
||||
|
||||
return numOfElem;
|
||||
*pNumOfElem = numOfElem;
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int32_t doAddNumericVector(SColumnInfoData* pCol, int32_t type, SInputColumnInfoData *pInput, void* pRes) {
|
||||
static int32_t doAddNumericVector(SColumnInfoData* pCol, int32_t type, SInputColumnInfoData *pInput, void* pRes, int32_t* pNumOfElem) {
|
||||
int32_t start = pInput->startRowIndex;
|
||||
int32_t numOfRows = pInput->numOfRows;
|
||||
int32_t numOfElems = 0;
|
||||
|
@ -306,14 +307,17 @@ static int32_t doAddNumericVector(SColumnInfoData* pCol, int32_t type, SInputCol
|
|||
|
||||
numOfElems += 1;
|
||||
AVG_RES_INC_COUNT(pRes, type, 1);
|
||||
bool overflow = false;
|
||||
SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pRes), (const void*)(pDec + i * tDataTypes[type].bytes), type);
|
||||
if (overflow) return TSDB_CODE_DECIMAL_OVERFLOW;
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return numOfElems;
|
||||
*pNumOfElem = numOfElems;
|
||||
return 0;
|
||||
}
|
||||
|
||||
int32_t avgFunction(SqlFunctionCtx* pCtx) {
|
||||
|
@ -341,7 +345,8 @@ int32_t avgFunction(SqlFunctionCtx* pCtx) {
|
|||
if (IS_DECIMAL_TYPE(type)) AVG_RES_SET_INPUT_SCALE(pAvgRes, pInput->pData[0]->info.scale);
|
||||
|
||||
if (pInput->colDataSMAIsSet) { // try to use SMA if available
|
||||
numOfElem = calculateAvgBySMAInfo(pAvgRes, numOfRows, type, pAgg);
|
||||
int32_t code = calculateAvgBySMAInfo(pAvgRes, numOfRows, type, pAgg, &numOfElem);
|
||||
if (code != 0) return code;
|
||||
} else if (!pCol->hasNull) { // try to employ the simd instructions to speed up the loop
|
||||
numOfElem = pInput->numOfRows;
|
||||
AVG_RES_INC_COUNT(pAvgRes, pCtx->inputType, pInput->numOfRows);
|
||||
|
@ -424,6 +429,7 @@ int32_t avgFunction(SqlFunctionCtx* pCtx) {
|
|||
const char* pDec = pCol->pData;
|
||||
// TODO wjm check for overflow
|
||||
for (int32_t i = pInput->startRowIndex; i < pInput->numOfRows + pInput->startRowIndex; ++i) {
|
||||
bool overflow = false;
|
||||
if (type == TSDB_DATA_TYPE_DECIMAL64) {
|
||||
SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pAvgRes), (const void*)(pDec + i * tDataTypes[type].bytes),
|
||||
TSDB_DATA_TYPE_DECIMAL64);
|
||||
|
@ -431,13 +437,14 @@ int32_t avgFunction(SqlFunctionCtx* pCtx) {
|
|||
SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pAvgRes), (const void*)(pDec + i * tDataTypes[type].bytes),
|
||||
TSDB_DATA_TYPE_DECIMAL);
|
||||
}
|
||||
if (overflow) return TSDB_CODE_DECIMAL_OVERFLOW;
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
return TSDB_CODE_FUNC_FUNTION_PARA_TYPE;
|
||||
}
|
||||
} else {
|
||||
numOfElem = doAddNumericVector(pCol, type, pInput, pAvgRes);
|
||||
int32_t code = doAddNumericVector(pCol, type, pInput, pAvgRes, &numOfElem);
|
||||
}
|
||||
|
||||
_over:
|
||||
|
@ -446,12 +453,12 @@ _over:
|
|||
return TSDB_CODE_SUCCESS;
|
||||
}
|
||||
|
||||
static void avgTransferInfo(SqlFunctionCtx* pCtx, void* pInput, void* pOutput) {
|
||||
static int32_t avgTransferInfo(SqlFunctionCtx* pCtx, void* pInput, void* pOutput) {
|
||||
int32_t inputDT = pCtx->pExpr->pExpr->_function.pFunctNode->srcFuncInputType.type;
|
||||
int32_t type = AVG_RES_GET_TYPE(pInput, inputDT);
|
||||
pCtx->inputType = type;
|
||||
if (IS_NULL_TYPE(type)) {
|
||||
return;
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
|
@ -464,12 +471,15 @@ static void avgTransferInfo(SqlFunctionCtx* pCtx, void* pInput, void* pOutput) {
|
|||
CHECK_OVERFLOW_SUM_UNSIGNED_BIG(pOutput, (overflow ? SUM_RES_GET_DSUM(&AVG_RES_GET_SUM(pInput)) : SUM_RES_GET_USUM(&AVG_RES_GET_SUM(pInput))), overflow);
|
||||
} else if (IS_DECIMAL_TYPE(type)) {
|
||||
AVG_RES_SET_INPUT_SCALE(pOutput, AVG_RES_GET_INPUT_SCALE(pInput));
|
||||
bool overflow = false;
|
||||
SUM_RES_INC_DECIMAL_SUM(&AVG_RES_GET_DECIMAL_SUM(pOutput), &AVG_RES_GET_DECIMAL_SUM(pInput), TSDB_DATA_TYPE_DECIMAL);
|
||||
if (overflow) return TSDB_CODE_DECIMAL_OVERFLOW;
|
||||
} else {
|
||||
SUM_RES_INC_DSUM(&AVG_RES_GET_SUM(pOutput), SUM_RES_GET_DSUM(&AVG_RES_GET_SUM(pInput)));
|
||||
}
|
||||
|
||||
AVG_RES_INC_COUNT(pOutput, type, AVG_RES_GET_COUNT(pInput, true, type));
|
||||
return 0;
|
||||
}
|
||||
|
||||
int32_t avgFunctionMerge(SqlFunctionCtx* pCtx) {
|
||||
|
@ -493,7 +503,8 @@ int32_t avgFunctionMerge(SqlFunctionCtx* pCtx) {
|
|||
if(colDataIsNull_s(pCol, i)) continue;
|
||||
char* data = colDataGetData(pCol, i);
|
||||
void* pInputInfo = varDataVal(data);
|
||||
avgTransferInfo(pCtx, pInputInfo, pInfo);
|
||||
int32_t code = avgTransferInfo(pCtx, pInputInfo, pInfo);
|
||||
if (code != 0) return code;
|
||||
}
|
||||
|
||||
SET_VAL(GET_RES_INFO(pCtx), 1, 1);
|
||||
|
|
|
@ -619,7 +619,7 @@ int32_t doMinMaxHelper(SqlFunctionCtx* pCtx, int32_t isMinFunc, int32_t* nElems)
|
|||
|
||||
int16_t index = 0;
|
||||
void* tval = NULL;
|
||||
if (type == TSDB_DATA_TYPE_DECIMAL) {
|
||||
if (IS_DECIMAL_TYPE(type)) {
|
||||
tval = isMinFunc ? pInput->pColumnDataAgg[0]->decimal128Min : pInput->pColumnDataAgg[0]->decimal128Max;
|
||||
} else {
|
||||
tval = (isMinFunc) ? &pInput->pColumnDataAgg[0]->min : &pInput->pColumnDataAgg[0]->max;
|
||||
|
|
|
@ -4228,7 +4228,7 @@ static int32_t fltSclBuildDecimalDatumFromValueNode(SFltSclDatum* datum, SColumn
|
|||
if (datum->kind == FLT_SCL_DATUM_KIND_DECIMAL64 || datum->kind == FLT_SCL_DATUM_KIND_DECIMAL) {
|
||||
int32_t code = convertToDecimal(pInput, &valDt, pData, &datum->type);
|
||||
if (TSDB_CODE_SUCCESS != code) return code; // TODO wjm handle overflow error
|
||||
valNode->node.resType = datum->type;
|
||||
//valNode->node.resType = datum->type;
|
||||
}
|
||||
}
|
||||
FLT_RET(0);
|
||||
|
|
|
@ -7,8 +7,10 @@ import time
|
|||
import threading
|
||||
import secrets
|
||||
import numpy
|
||||
from paramiko import Agent
|
||||
|
||||
import query
|
||||
from tag_lite import datatype
|
||||
from util.log import *
|
||||
from util.sql import *
|
||||
from util.cases import *
|
||||
|
@ -43,12 +45,13 @@ scalar_convert_err = -2147470768
|
|||
|
||||
|
||||
decimal_insert_validator_test = False
|
||||
operator_test_round = 10
|
||||
operator_test_round = 1
|
||||
tb_insert_rows = 1000
|
||||
binary_op_with_const_test = True
|
||||
binary_op_with_col_test = True
|
||||
unary_op_test = True
|
||||
binary_op_in_where_test = True
|
||||
test_decimal_funcs = True
|
||||
|
||||
class DecimalTypeGeneratorConfig:
|
||||
def __init__(self):
|
||||
|
@ -156,14 +159,14 @@ class DecimalColumnAggregator:
|
|||
self.null_num: int = 0
|
||||
self.none_num: int = 0
|
||||
|
||||
def add_value(self, value: str):
|
||||
def add_value(self, value: str, scale: int):
|
||||
self.count += 1
|
||||
if value == "NULL":
|
||||
self.null_num += 1
|
||||
elif value == "None":
|
||||
self.none_num += 1
|
||||
else:
|
||||
v: Decimal = Decimal(value)
|
||||
v: Decimal = get_decimal(value, scale)
|
||||
self.sum += v
|
||||
if v > self.max:
|
||||
self.max = v
|
||||
|
@ -245,6 +248,10 @@ class DecimalColumnExpr:
|
|||
def should_skip_for_decimal(self, cols: list):
|
||||
return False
|
||||
|
||||
def check_query_results(self, query_col_res: List, tbname: str):
|
||||
query_len = len(query_col_res)
|
||||
pass
|
||||
|
||||
def check_for_filtering(self, query_col_res: List, tbname: str):
|
||||
j: int = -1
|
||||
for i in range(len(query_col_res)):
|
||||
|
@ -479,6 +486,7 @@ class DataType:
|
|||
return val
|
||||
|
||||
class DecimalType(DataType):
|
||||
MAX_PRECISION = 38
|
||||
def __init__(self, type, precision: int, scale: int):
|
||||
self.precision_ = precision
|
||||
self.scale_ = scale
|
||||
|
@ -496,6 +504,14 @@ class DecimalType(DataType):
|
|||
|
||||
def get_decimal_type_mod(self) -> int:
|
||||
return self.precision_ * 100 + self.scale()
|
||||
|
||||
def set_prec(self, prec: int):
|
||||
self.precision_ = prec
|
||||
self.type_mod = self.get_decimal_type_mod()
|
||||
|
||||
def set_scale(self, scale: int):
|
||||
self.scale_ = scale
|
||||
self.type_mod = self.get_decimal_type_mod()
|
||||
|
||||
def prec(self):
|
||||
return self.precision_
|
||||
|
@ -520,7 +536,7 @@ class DecimalType(DataType):
|
|||
|
||||
def generate_value(self) -> str:
|
||||
val = self.decimal_generator.generate(self.generator_config)
|
||||
self.aggregator.add_value(val) ## convert to Decimal first
|
||||
self.aggregator.add_value(val, self.scale()) ## convert to Decimal first
|
||||
# self.values.append(val) ## save it into files maybe
|
||||
return val
|
||||
|
||||
|
@ -605,6 +621,12 @@ class Column:
|
|||
if self.is_constant_col():
|
||||
return self.get_constant_val_for_execute()
|
||||
return self.get_typed_val_for_execute(self.saved_vals[tbname][idx])
|
||||
|
||||
def get_cardinality(self, tbname):
|
||||
if self.is_constant_col():
|
||||
return 1
|
||||
else:
|
||||
return len(self.saved_vals[tbname])
|
||||
|
||||
## tbName: for normal table, pass the tbname, for child table, pass the child table name
|
||||
def generate_value(self, tbName: str = '', save: bool = True):
|
||||
|
@ -824,6 +846,143 @@ class TableDataValidator:
|
|||
col.check(res[colIdx], row_num * self.tbIdx)
|
||||
colIdx += 1
|
||||
|
||||
class DecimalFunction(DecimalColumnExpr):
|
||||
def __init__(self, format, executor, name: str):
|
||||
super().__init__(format, executor)
|
||||
self.func_name_ = name
|
||||
|
||||
def is_agg_func(self, op: str):
|
||||
return False
|
||||
|
||||
def get_func_res(self):
|
||||
return None
|
||||
|
||||
def get_input_types(self):
|
||||
return [self.query_col]
|
||||
|
||||
@staticmethod
|
||||
def get_decimal_agg_funcs() -> List:
|
||||
return [DecimalMinFunction(), DecimalMaxFunction(), DecimalSumFunction(), DecimalAvgFunction()]
|
||||
|
||||
def check_results(self, query_col_res: List) -> bool:
|
||||
return False
|
||||
|
||||
def check_for_agg_func(self, query_col_res: List, tbname: str):
|
||||
col_expr = self.query_col
|
||||
for i in range(col_expr.get_cardinality(tbname)):
|
||||
col_val = col_expr.get_val_for_execute(tbname, i)
|
||||
self.execute((col_val,))
|
||||
if not self.check_results(query_col_res):
|
||||
tdLog.exit(f"check failed for {self}, query got: {query_col_res}, expect {self.get_func_res()}")
|
||||
|
||||
class DecimalAggFunction(DecimalFunction):
|
||||
def __init__(self, format, executor, name: str):
|
||||
super().__init__(format, executor, name)
|
||||
|
||||
def is_agg_func(self, op):
|
||||
return True
|
||||
|
||||
def should_skip_for_decimal(self, cols: list):
|
||||
col: Column = cols[0]
|
||||
if col.type_.is_decimal_type():
|
||||
return False
|
||||
return True
|
||||
|
||||
def check_results(self, query_col_res):
|
||||
if len(query_col_res) == 0:
|
||||
tdLog.info(f"query got no output: {self}, py calc: {self.get_func_res()}")
|
||||
return True
|
||||
else:
|
||||
return self.get_func_res() == Decimal(query_col_res[0])
|
||||
|
||||
class DecimalMinFunction(DecimalAggFunction):
|
||||
def __init__(self):
|
||||
super().__init__("min({0})", DecimalMinFunction.execute_min, "min")
|
||||
self.min_: Decimal = None
|
||||
|
||||
def get_func_res(self) -> Decimal:
|
||||
decimal_type: DecimalType = self.query_col.type_
|
||||
return decimal_type.aggregator.min
|
||||
return self.min_
|
||||
|
||||
def generate_res_type(self) -> DataType:
|
||||
self.res_type_ = self.query_col.type_
|
||||
|
||||
def execute_min(self, params):
|
||||
if params[0] is None:
|
||||
return
|
||||
if self.min_ is None:
|
||||
self.min_ = Decimal(params[0])
|
||||
else:
|
||||
self.min_ = min(self.min_, Decimal(params[0]))
|
||||
return self.min_
|
||||
|
||||
class DecimalMaxFunction(DecimalAggFunction):
|
||||
def __init__(self):
|
||||
super().__init__("max({0})", DecimalMaxFunction.execute_max, "max")
|
||||
self.max_: Decimal = None
|
||||
|
||||
def get_func_res(self) -> Decimal:
|
||||
return self.max_
|
||||
|
||||
def generate_res_type(self) -> DataType:
|
||||
self.res_type_ = self.query_col.type_
|
||||
|
||||
def execute_max(self, params):
|
||||
if params[0] is None:
|
||||
return
|
||||
if self.max_ is None:
|
||||
self.max_ = Decimal(params[0])
|
||||
else:
|
||||
self.max_ = max(self.max_, Decimal(params[0]))
|
||||
return self.max_
|
||||
|
||||
class DecimalSumFunction(DecimalAggFunction):
|
||||
def __init__(self):
|
||||
super().__init__("sum({0})", DecimalSumFunction.execute_sum, "sum")
|
||||
self.sum_:Decimal = None
|
||||
def get_func_res(self) -> Decimal:
|
||||
decimal_type: DecimalType = self.query_col.type_
|
||||
return decimal_type.aggregator.sum
|
||||
return self.sum_
|
||||
def generate_res_type(self) -> DataType:
|
||||
self.res_type_ = self.query_col.type_
|
||||
self.res_type_.set_prec(DecimalType.MAX_PRECISION)
|
||||
def execute_sum(self, params):
|
||||
if params[0] is None:
|
||||
return
|
||||
if self.sum_ is None:
|
||||
self.sum_ = Decimal(params[0])
|
||||
else:
|
||||
self.sum_ += Decimal(params[0])
|
||||
return self.sum_
|
||||
|
||||
class DecimalAvgFunction(DecimalAggFunction):
|
||||
def __init__(self):
|
||||
super().__init__("avg({0})", DecimalAvgFunction.execute_avg, "avg")
|
||||
self.count_: Decimal = 0
|
||||
self.sum_: Decimal = None
|
||||
def get_func_res(self) -> Decimal:
|
||||
decimal_type: DecimalType = self.query_col.type_
|
||||
return get_decimal(
|
||||
decimal_type.aggregator.sum
|
||||
/ (decimal_type.aggregator.count - decimal_type.aggregator.null_num),
|
||||
self.res_type_.scale(),
|
||||
)
|
||||
def generate_res_type(self) -> DataType:
|
||||
sum_type = self.query_col.type_
|
||||
sum_type.set_prec(DecimalType.MAX_PRECISION)
|
||||
count_type = DataType(TypeEnum.BIGINT, 8, 0)
|
||||
self.res_type_ = DecimalBinaryOperator.calc_decimal_prec_scale(sum_type, count_type, "/")
|
||||
def execute_avg(self, params):
|
||||
if params[0] is None:
|
||||
return
|
||||
if self.sum_ is None:
|
||||
self.sum_ = Decimal(params[0])
|
||||
else:
|
||||
self.sum_ += Decimal(params[0])
|
||||
self.count_ += 1
|
||||
return self.get_func_res()
|
||||
|
||||
class DecimalBinaryOperator(DecimalColumnExpr):
|
||||
def __init__(self, format, executor, op: str):
|
||||
|
@ -845,7 +1004,7 @@ class DecimalBinaryOperator(DecimalColumnExpr):
|
|||
return True
|
||||
if self.op_ != "%":
|
||||
return False
|
||||
## why skip decimal % float/double? it's wrong now.
|
||||
## TODO wjm why skip decimal % float/double? it's wrong now.
|
||||
left_is_real = left_col.type_.is_real_type() or left_col.type_.is_varchar_type()
|
||||
right_is_real = right_col.type_.is_real_type() or right_col.type_.is_varchar_type()
|
||||
if left_is_real or right_is_real:
|
||||
|
@ -1554,7 +1713,7 @@ class TDTestCase:
|
|||
threads.append(t)
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
|
||||
def run_in_thread2(self, func, params) -> threading.Thread:
|
||||
t = threading.Thread(target=func, args=params)
|
||||
t.start()
|
||||
|
@ -1638,10 +1797,6 @@ class TDTestCase:
|
|||
|
||||
self.test_query_decimal_where_clause()
|
||||
|
||||
def test_decimal_functions(self):
|
||||
self.test_decimal_last_first_func()
|
||||
funcs = ["max", "min", "sum", "avg", "count", "first", "last", "cast"]
|
||||
|
||||
def test_decimal_last_first_func(self):
|
||||
pass
|
||||
|
||||
|
@ -1663,9 +1818,7 @@ class TDTestCase:
|
|||
for col in tb_cols:
|
||||
if col.name_ == '':
|
||||
continue
|
||||
left_is_decimal = col.type_.is_decimal_type()
|
||||
for const_col in constant_cols:
|
||||
right_is_decimal = const_col.type_.is_decimal_type()
|
||||
if expr.should_skip_for_decimal([col, const_col]):
|
||||
continue
|
||||
const_col.generate_value()
|
||||
|
@ -1740,7 +1893,7 @@ class TDTestCase:
|
|||
self.norm_table_name,
|
||||
self.norm_tb_columns,
|
||||
binary_compare_ops)
|
||||
|
||||
|
||||
## TODO wjm
|
||||
## 3. (dec op const col) op const col
|
||||
## 4. (dec op dec) op const col
|
||||
|
@ -1765,6 +1918,32 @@ class TDTestCase:
|
|||
def test_query_decimal_case_when(self):
|
||||
pass
|
||||
|
||||
def test_decimal_agg_funcs(self, dbname, tbname, tb_cols: List[Column], get_agg_funcs_func):
|
||||
agg_funcs: List[DecimalFunction] = get_agg_funcs_func()
|
||||
for func in agg_funcs:
|
||||
for col in tb_cols:
|
||||
if col.name_ == '' or func.should_skip_for_decimal([col]):
|
||||
continue
|
||||
func.query_col = col
|
||||
select_expr = func.generate([col])
|
||||
sql = f"select {select_expr} from {dbname}.{tbname}"
|
||||
res = TaosShell().query(sql)
|
||||
if len(res) > 0:
|
||||
res = res[0]
|
||||
func.check_for_agg_func(res, tbname)
|
||||
|
||||
def test_decimal_functions(self):
|
||||
if not test_decimal_funcs:
|
||||
return
|
||||
self.test_decimal_agg_funcs(
|
||||
self.db_name,
|
||||
self.norm_table_name,
|
||||
self.norm_tb_columns,
|
||||
DecimalFunction.get_decimal_agg_funcs,
|
||||
)
|
||||
|
||||
self.test_decimal_last_first_func()
|
||||
|
||||
def test_query_decimal(self):
|
||||
self.test_decimal_operators()
|
||||
self.test_decimal_functions()
|
||||
|
|
Loading…
Reference in New Issue