fix(decimal): fix decimal tests

This commit is contained in:
wangjiaming0909 2025-03-11 09:18:12 +08:00
parent 2526fb8448
commit 97412e2136
7 changed files with 39 additions and 22 deletions

View File

@ -184,6 +184,8 @@ typedef struct SColumnDataAgg {
} SColumnDataAgg;
#pragma pack(pop)
#define DECIMAL_AGG_FLAG 0x80000000
#define COL_AGG_GET_SUM_PTR(pAggs, dataType) \
(!IS_DECIMAL_TYPE(dataType) ? (void*)&pAggs->sum : (void*)pAggs->decimal128Sum)

View File

@ -551,9 +551,9 @@ int32_t setResSchemaInfo(SReqResultInfo* pResInfo, const SSchema* pSchema, int32
if (pSchema[i].type == TSDB_DATA_TYPE_VARCHAR || pSchema[i].type == TSDB_DATA_TYPE_VARBINARY ||
pSchema[i].type == TSDB_DATA_TYPE_GEOMETRY) {
pResInfo->userFields[i].bytes -= VARSTR_HEADER_SIZE;
pResInfo->fields[i].bytes = pResInfo->userFields[i].bytes -= VARSTR_HEADER_SIZE;
} else if (pSchema[i].type == TSDB_DATA_TYPE_NCHAR || pSchema[i].type == TSDB_DATA_TYPE_JSON) {
pResInfo->userFields[i].bytes = (pResInfo->userFields[i].bytes - VARSTR_HEADER_SIZE) / TSDB_NCHAR_SIZE;
pResInfo->fields[i].bytes = pResInfo->userFields[i].bytes = (pResInfo->userFields[i].bytes - VARSTR_HEADER_SIZE) / TSDB_NCHAR_SIZE;
} else if (IS_DECIMAL_TYPE(pSchema[i].type) && pExtSchema) {
decimalFromTypeMod(pExtSchema[i].typeMod, &pResInfo->fields[i].precision, &pResInfo->fields[i].scale);
}

View File

@ -4242,7 +4242,7 @@ static FORCE_INLINE void tColDataCalcSMADecimal64Type(SColData* pColData, SColum
*pMax = DECIMAL64_MIN;
*pMin = DECIMAL64_MAX;
pAggs->numOfNull = 0;
pAggs->colId |= 0x80000000; // TODO wjm define it
pAggs->colId |= DECIMAL_AGG_FLAG;
Decimal64 *pVal = NULL;
const SDecimalOps *pSumOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL);
@ -4278,7 +4278,7 @@ static FORCE_INLINE void tColDataCalcSMADecimal128Type(SColData* pColData, SColu
*pMax = DECIMAL128_MIN;
*pMin = DECIMAL128_MAX;
pAggs->numOfNull = 0;
pAggs->colId |= 0x80000000; // TODO wjm define it
pAggs->colId |= DECIMAL_AGG_FLAG;
Decimal128 *pVal = NULL;
const SDecimalOps *pOps = getDecimalOps(TSDB_DATA_TYPE_DECIMAL);

View File

@ -1284,7 +1284,6 @@ static int32_t mndBuildStbFromAlter(SStbObj *pStb, SStbObj *pDst, SMCreateStbReq
} else {
p->alg = pField->compress;
}
// TODO wjm test it with tmq
if (pField->flags & COL_HAS_TYPE_MOD) {
pDst->pExtSchemas[i].typeMod = pField->typeMod;
}

View File

@ -1576,7 +1576,7 @@ int32_t tGetDiskDataHdr(SBufferReader *br, SDiskDataHdr *pHdr) {
int32_t tPutColumnDataAgg(SBuffer *buffer, SColumnDataAgg *pColAgg) {
int32_t code;
if (pColAgg->colId & 0x80000000) {
if (pColAgg->colId & DECIMAL_AGG_FLAG) {
if ((code = tBufferPutI32v(buffer, pColAgg->colId))) return code;
if ((code = tBufferPutI16v(buffer, pColAgg->numOfNull))) return code;
if ((code = tBufferPutU64(buffer, pColAgg->decimal128Sum[0]))) return code;
@ -1602,7 +1602,7 @@ int32_t tGetColumnDataAgg(SBufferReader *br, SColumnDataAgg *pColAgg) {
if ((code = tBufferGetI32v(br, &pColAgg->colId))) return code;
if ((code = tBufferGetI16v(br, &pColAgg->numOfNull))) return code;
if (pColAgg->colId & 0x80000000) {
if (pColAgg->colId & DECIMAL_AGG_FLAG) {
pColAgg->colId &= 0xFFFF;
if ((code = tBufferGetU64(br, &pColAgg->decimal128Sum[0]))) return code;
if ((code = tBufferGetU64(br, &pColAgg->decimal128Sum[1]))) return code;

View File

@ -536,7 +536,6 @@ static const int32_t bitsForNumDigits[] = {0, 4, 7, 10, 14, 17, 20, 24, 2
44, 47, 50, 54, 57, 60, 64, 67, 70, 74, 77, 80, 84,
87, 90, 94, 97, 100, 103, 107, 110, 113, 117, 120, 123, 127};
// TODO wjm pre define it?? actually, its MAX_INTEGER, not MAX
#define DECIMAL128_GET_MAX(precision, pMax) \
do { \
*(pMax) = SCALE_MULTIPLIER_128[precision]; \
@ -732,7 +731,7 @@ static int32_t decimal128ToStr(const DecimalType* pInt, uint8_t scale, char* pBu
TAOS_STRNCAT(pBuf, "0", 2);
}
if (scale > 0) {
static char format[64] = "0000000000000000000000000000000000000000";
static const char *format = "0000000000000000000000000000000000000000";
TAOS_STRNCAT(pBuf, ".", 2);
if (wholeLen < 0) TAOS_STRNCAT(pBuf, format, TABS(wholeLen));
TAOS_STRNCAT(pBuf, buf + TMAX(0, wholeLen), scale);
@ -754,7 +753,7 @@ int32_t decimalToStr(const DecimalType* pDec, int8_t dataType, int8_t precision,
return 0;
}
static void decimalAddLargePositive(Decimal* pX, const SDataType* pXT, const Decimal* pY, const SDataType* pYT,
static int32_t decimalAddLargePositive(Decimal* pX, const SDataType* pXT, const Decimal* pY, const SDataType* pYT,
const SDataType* pOT) {
Decimal wholeX = *pX, wholeY = *pY, fracX = {0}, fracY = {0};
decimal128Divide(&wholeX, &SCALE_MULTIPLIER_128[pXT->scale], WORD_NUM(Decimal), &fracX);
@ -776,11 +775,18 @@ static void decimalAddLargePositive(Decimal* pX, const SDataType* pXT, const Dec
}
decimal128ScaleDown(&right, maxScale - pOT->scale, true);
if (decimal128AddCheckOverflow(&wholeX, &wholeY, WORD_NUM(Decimal))) {
return TSDB_CODE_DECIMAL_OVERFLOW;
}
decimal128Add(&wholeX, &wholeY, WORD_NUM(Decimal));
if (decimal128AddCheckOverflow(&wholeX, &carry, WORD_NUM(Decimal))) {
return TSDB_CODE_DECIMAL_OVERFLOW;
}
decimal128Add(&wholeX, &carry, WORD_NUM(Decimal));
decimal128Multiply(&wholeX, &SCALE_MULTIPLIER_128[pOT->scale], WORD_NUM(Decimal));
decimal128Add(&wholeX, &right, WORD_NUM(Decimal));
*pX = wholeX;
return 0;
}
static void decimalAddLargeNegative(Decimal* pX, const SDataType* pXT, const Decimal* pY, const SDataType* pYT,
@ -810,8 +816,9 @@ static void decimalAddLargeNegative(Decimal* pX, const SDataType* pXT, const Dec
*pX = wholeX;
}
static void decimalAdd(Decimal* pX, const SDataType* pXT, const Decimal* pY, const SDataType* pYT,
static int32_t decimalAdd(Decimal* pX, const SDataType* pXT, const Decimal* pY, const SDataType* pYT,
const SDataType* pOT) {
int32_t code = 0;
if (pOT->precision < TSDB_DECIMAL_MAX_PRECISION) {
uint8_t maxScale = TMAX(pXT->scale, pYT->scale);
Decimal tmpY = *pY;
@ -821,17 +828,18 @@ static void decimalAdd(Decimal* pX, const SDataType* pXT, const Decimal* pY, con
} else {
int8_t signX = DECIMAL128_SIGN(pX), signY = DECIMAL128_SIGN(pY);
if (signX == 1 && signY == 1) {
decimalAddLargePositive(pX, pXT, pY, pYT, pOT);
code = decimalAddLargePositive(pX, pXT, pY, pYT, pOT);
} else if (signX == -1 && signY == -1) {
decimal128Negate(pX);
Decimal y = *pY;
decimal128Negate(&y);
decimalAddLargePositive(pX, pXT, &y, pYT, pOT);
code = decimalAddLargePositive(pX, pXT, &y, pYT, pOT);
decimal128Negate(pX);
} else {
decimalAddLargeNegative(pX, pXT, pY, pYT, pOT);
}
}
return code;
}
static void makeInt256FromDecimal128(Int256* pTarget, const Decimal128* pDec) {
@ -952,7 +960,7 @@ static int32_t decimalDivide(Decimal* pX, const SDataType* pXT, const Decimal* p
int8_t deltaScale = pOT->scale + pYT->scale - pXT->scale;
Decimal xTmp = *pX;
decimal128Abs(&xTmp); // TODO wjm test decimal64 / decimal64
decimal128Abs(&xTmp);
int32_t bitsOccupied = 128 - decimal128CountLeadingBinaryZeros(&xTmp);
if (bitsOccupied + bitsForNumDigits[deltaScale] <= 127) {
xTmp = *pX;
@ -1036,7 +1044,6 @@ static int32_t decimalMod(Decimal* pX, const SDataType* pXT, const Decimal* pY,
int32_t decimalOp(EOperatorType op, const SDataType* pLeftT, const SDataType* pRightT, const SDataType* pOutT,
const void* pLeftData, const void* pRightData, void* pOutputData) {
int32_t code = 0;
// TODO wjm if output precision <= 18, no need to convert to decimal128
Decimal left = {0}, right = {0};
SDataType lt = {.type = TSDB_DATA_TYPE_DECIMAL,
@ -1050,7 +1057,7 @@ int32_t decimalOp(EOperatorType op, const SDataType* pLeftT, const SDataType* pR
if (pRightT) rt.scale = pRightT->scale;
if (TSDB_DATA_TYPE_DECIMAL != pLeftT->type) {
code = convertToDecimal(pLeftData, pLeftT, &left, &lt);
if (TSDB_CODE_SUCCESS != code) return code; // TODO add some logs here
if (TSDB_CODE_SUCCESS != code) return code;
} else {
left = *(Decimal*)pLeftData;
}
@ -1070,11 +1077,11 @@ int32_t decimalOp(EOperatorType op, const SDataType* pLeftT, const SDataType* pR
switch (op) {
case OP_TYPE_ADD:
// TODO wjm check overflow
decimalAdd(&left, &lt, &right, &rt, pOutT);
code = decimalAdd(&left, &lt, &right, &rt, pOutT);
break;
case OP_TYPE_SUB:
decimal128Negate(&right);
decimalAdd(&left, &lt, &right, &rt, pOutT);
code = decimalAdd(&left, &lt, &right, &rt, pOutT);
break;
case OP_TYPE_MULTI:
code = decimalMultiply(&left, &lt, &right, &rt, pOutT);
@ -1339,7 +1346,7 @@ static int32_t decimal64FromDecimal64(DecimalType* pDec, uint8_t prec, uint8_t s
static int64_t int64FromDecimal128(const DecimalType* pDec, uint8_t prec, uint8_t scale) {
Decimal128 rounded = *(Decimal128*)pDec;
bool overflow = false; // TODO wjm pass out the overflow??
bool overflow = false;
decimal128RoundWithPositiveScale(&rounded, prec, scale, prec, 0, ROUND_TYPE_HALF_ROUND_UP, &overflow);
if (overflow) {
return 0;
@ -1372,7 +1379,7 @@ static uint64_t uint64FromDecimal128(const DecimalType* pDec, uint8_t prec, uint
}
static int32_t decimal128FromInt64(DecimalType* pDec, uint8_t prec, uint8_t scale, int64_t val) {
if (prec - scale <= 18) { // TODO wjm test int64 with 19 digits.
if (prec - scale <= 18) {
Decimal64 max = {0};
DECIMAL64_GET_MAX(prec - scale, &max);
if (DECIMAL64_GET_VALUE(&max) < val || -DECIMAL64_GET_VALUE(&max) > val) return TSDB_CODE_DECIMAL_OVERFLOW;

View File

@ -1385,8 +1385,8 @@ TEST_F(DecimalTest, api_taos_fetch_rows) {
const char* user = "root";
const char* passwd = "taosdata";
const char* db = "test_api";
const char* create_tb = "create table if not exists test_api.nt(ts timestamp, c1 decimal(10, 2), c2 decimal(38, 10))";
const char* sql = "select c1, c2 from test_api.nt";
const char* create_tb = "create table if not exists test_api.nt(ts timestamp, c1 decimal(10, 2), c2 decimal(38, 10), c3 varchar(255))";
const char* sql = "select c1, c2,c3 from test_api.nt";
const char* sql_insert = "insert into test_api.nt values(now, 123456.123, 98472981092.1209111)";
TAOS* pTaos = taos_connect(host, user, passwd, NULL, 0);
@ -1423,6 +1423,8 @@ TEST_F(DecimalTest, api_taos_fetch_rows) {
ASSERT_EQ(fields_e[0].scale, 2);
ASSERT_EQ(fields_e[1].precision, 38);
ASSERT_EQ(fields_e[1].scale, 10);
ASSERT_EQ(fields_e[2].type, TSDB_DATA_TYPE_VARCHAR);
ASSERT_EQ(fields_e[2].bytes, 255);
taos_free_result(res);
res = taos_query(pTaos, sql);
@ -1491,10 +1493,17 @@ TEST(decimal, test_add_check_overflow) {
Numeric<64> dec64 = {18, 2, "123.12"};
bool overflow = decimal128AddCheckOverflow((Decimal128*)&dec128.dec(), &dec64.dec(), WORD_NUM(Decimal64));
ASSERT_TRUE(overflow);
auto ret = dec128 + dec64;
dec128 = {38, 10, "-9999999999999999999999999999.9999999999"};
ASSERT_FALSE(decimal128AddCheckOverflow((Decimal128*)&dec128.dec(), &dec64.dec(), WORD_NUM(Decimal64)));
dec64 = {18, 2, "-123.1"};
ASSERT_TRUE(decimal128AddCheckOverflow((Decimal128*)&dec128.dec(), &dec64.dec(), WORD_NUM(Decimal64)));
dec128 = {38, 0, "99999999999999999999999999999999999999"};
dec64= {18, 0, "123"};
Numeric<128> dec128_2 = {38, 2, "999999999999999999999999999999999999.99"};
ASSERT_OVERFLOW(dec128 + dec128_2);
ASSERT_OVERFLOW(dec128 + dec64);
}
int main(int argc, char** argv) {