From 82ffeb82a95fddb32a794329af609a7619e618de Mon Sep 17 00:00:00 2001 From: Haojun Liao Date: Sun, 23 Mar 2025 13:20:24 +0800 Subject: [PATCH] fix(gpt): fix the bug for algo name case sensitive (#30362) --- include/common/tmsg.h | 14 ++++++------ source/common/src/msg/tmsg.c | 15 ++++++++----- source/dnode/mgmt/mgmt_dnode/src/dmHandle.c | 6 ++--- source/dnode/mgmt/node_mgmt/src/dmTransport.c | 10 ++++----- source/dnode/mnode/impl/src/mndAnode.c | 12 +++++----- source/libs/executor/src/forecastoperator.c | 1 + tools/tdgpt/taosanalytics/algo/fc/arima.py | 6 ++--- tools/tdgpt/taosanalytics/algo/fc/gpt.py | 10 ++++----- .../taosanalytics/algo/fc/holtwinters.py | 6 ++--- tools/tdgpt/taosanalytics/algo/fc/lstm.py | 2 +- tools/tdgpt/taosanalytics/algo/forecast.py | 2 +- tools/tdgpt/taosanalytics/app.py | 2 +- tools/tdgpt/taosanalytics/service.py | 12 +++++----- .../tdgpt/taosanalytics/test/forecast_test.py | 22 +++++++++---------- 14 files changed, 62 insertions(+), 58 deletions(-) diff --git a/include/common/tmsg.h b/include/common/tmsg.h index c295c40c1e..75a800f1c3 100644 --- a/include/common/tmsg.h +++ b/include/common/tmsg.h @@ -1257,18 +1257,18 @@ int32_t tDeserializeRetrieveIpWhite(void* buf, int32_t bufLen, SRetrieveIpWhiteR typedef struct { int32_t dnodeId; int64_t analVer; -} SRetrieveAnalAlgoReq; +} SRetrieveAnalyticsAlgoReq; typedef struct { int64_t ver; SHashObj* hash; // algoname:algotype -> SAnalUrl -} SRetrieveAnalAlgoRsp; +} SRetrieveAnalyticAlgoRsp; -int32_t tSerializeRetrieveAnalAlgoReq(void* buf, int32_t bufLen, SRetrieveAnalAlgoReq* pReq); -int32_t tDeserializeRetrieveAnalAlgoReq(void* buf, int32_t bufLen, SRetrieveAnalAlgoReq* pReq); -int32_t tSerializeRetrieveAnalAlgoRsp(void* buf, int32_t bufLen, SRetrieveAnalAlgoRsp* pRsp); -int32_t tDeserializeRetrieveAnalAlgoRsp(void* buf, int32_t bufLen, SRetrieveAnalAlgoRsp* pRsp); -void tFreeRetrieveAnalAlgoRsp(SRetrieveAnalAlgoRsp* pRsp); +int32_t tSerializeRetrieveAnalyticAlgoReq(void* buf, int32_t bufLen, SRetrieveAnalyticsAlgoReq* pReq); +int32_t tDeserializeRetrieveAnalyticAlgoReq(void* buf, int32_t bufLen, SRetrieveAnalyticsAlgoReq* pReq); +int32_t tSerializeRetrieveAnalyticAlgoRsp(void* buf, int32_t bufLen, SRetrieveAnalyticAlgoRsp* pRsp); +int32_t tDeserializeRetrieveAnalyticAlgoRsp(void* buf, int32_t bufLen, SRetrieveAnalyticAlgoRsp* pRsp); +void tFreeRetrieveAnalyticAlgoRsp(SRetrieveAnalyticAlgoRsp* pRsp); typedef struct { int8_t alterType; diff --git a/source/common/src/msg/tmsg.c b/source/common/src/msg/tmsg.c index 5d99ef8fea..930408ef9b 100644 --- a/source/common/src/msg/tmsg.c +++ b/source/common/src/msg/tmsg.c @@ -2297,7 +2297,7 @@ _exit: return code; } -int32_t tSerializeRetrieveAnalAlgoReq(void *buf, int32_t bufLen, SRetrieveAnalAlgoReq *pReq) { +int32_t tSerializeRetrieveAnalyticAlgoReq(void *buf, int32_t bufLen, SRetrieveAnalyticsAlgoReq *pReq) { SEncoder encoder = {0}; int32_t code = 0; int32_t lino; @@ -2319,7 +2319,7 @@ _exit: return tlen; } -int32_t tDeserializeRetrieveAnalAlgoReq(void *buf, int32_t bufLen, SRetrieveAnalAlgoReq *pReq) { +int32_t tDeserializeRetrieveAnalyticAlgoReq(void *buf, int32_t bufLen, SRetrieveAnalyticsAlgoReq *pReq) { SDecoder decoder = {0}; int32_t code = 0; int32_t lino; @@ -2336,7 +2336,7 @@ _exit: return code; } -int32_t tSerializeRetrieveAnalAlgoRsp(void *buf, int32_t bufLen, SRetrieveAnalAlgoRsp *pRsp) { +int32_t tSerializeRetrieveAnalyticAlgoRsp(void *buf, int32_t bufLen, SRetrieveAnalyticAlgoRsp *pRsp) { SEncoder encoder = {0}; int32_t code = 0; int32_t lino; @@ -2387,7 +2387,7 @@ _exit: return tlen; } -int32_t tDeserializeRetrieveAnalAlgoRsp(void *buf, int32_t bufLen, SRetrieveAnalAlgoRsp *pRsp) { +int32_t tDeserializeRetrieveAnalyticAlgoRsp(void *buf, int32_t bufLen, SRetrieveAnalyticAlgoRsp *pRsp) { if (pRsp->hash == NULL) { pRsp->hash = taosHashInit(64, MurmurHash3_32, true, HASH_ENTRY_LOCK); if (pRsp->hash == NULL) { @@ -2425,7 +2425,10 @@ int32_t tDeserializeRetrieveAnalAlgoRsp(void *buf, int32_t bufLen, SRetrieveAnal TAOS_CHECK_EXIT(tDecodeBinaryAlloc(&decoder, (void **)&url.url, NULL) < 0); } - TAOS_CHECK_EXIT(taosHashPut(pRsp->hash, name, nameLen, &url, sizeof(SAnalyticsUrl))); + char dstName[TSDB_ANALYTIC_ALGO_NAME_LEN] = {0}; + strntolower(dstName, name, nameLen); + + TAOS_CHECK_EXIT(taosHashPut(pRsp->hash, dstName, nameLen, &url, sizeof(SAnalyticsUrl))); } tEndDecode(&decoder); @@ -2435,7 +2438,7 @@ _exit: return code; } -void tFreeRetrieveAnalAlgoRsp(SRetrieveAnalAlgoRsp *pRsp) { +void tFreeRetrieveAnalyticAlgoRsp(SRetrieveAnalyticAlgoRsp *pRsp) { void *pIter = taosHashIterate(pRsp->hash, NULL); while (pIter != NULL) { SAnalyticsUrl *pUrl = (SAnalyticsUrl *)pIter; diff --git a/source/dnode/mgmt/mgmt_dnode/src/dmHandle.c b/source/dnode/mgmt/mgmt_dnode/src/dmHandle.c index fc4ead8973..80f8a749ea 100644 --- a/source/dnode/mgmt/mgmt_dnode/src/dmHandle.c +++ b/source/dnode/mgmt/mgmt_dnode/src/dmHandle.c @@ -98,15 +98,15 @@ static void dmMayShouldUpdateAnalFunc(SDnodeMgmt *pMgmt, int64_t newVer) { if (oldVer == newVer) return; dDebug("analysis on dnode ver:%" PRId64 ", status ver:%" PRId64, oldVer, newVer); - SRetrieveAnalAlgoReq req = {.dnodeId = pMgmt->pData->dnodeId, .analVer = oldVer}; - int32_t contLen = tSerializeRetrieveAnalAlgoReq(NULL, 0, &req); + SRetrieveAnalyticsAlgoReq req = {.dnodeId = pMgmt->pData->dnodeId, .analVer = oldVer}; + int32_t contLen = tSerializeRetrieveAnalyticAlgoReq(NULL, 0, &req); if (contLen < 0) { dError("failed to serialize analysis function ver request since %s", tstrerror(contLen)); return; } void *pHead = rpcMallocCont(contLen); - contLen = tSerializeRetrieveAnalAlgoReq(pHead, contLen, &req); + contLen = tSerializeRetrieveAnalyticAlgoReq(pHead, contLen, &req); if (contLen < 0) { rpcFreeCont(pHead); dError("failed to serialize analysis function ver request since %s", tstrerror(contLen)); diff --git a/source/dnode/mgmt/node_mgmt/src/dmTransport.c b/source/dnode/mgmt/node_mgmt/src/dmTransport.c index 8ab14cff2f..d89e90bf90 100644 --- a/source/dnode/mgmt/node_mgmt/src/dmTransport.c +++ b/source/dnode/mgmt/node_mgmt/src/dmTransport.c @@ -116,13 +116,13 @@ static bool dmIsForbiddenIp(int8_t forbidden, char *user, uint32_t clientIp) { } } -static void dmUpdateAnalFunc(SDnodeData *pData, void *pTrans, SRpcMsg *pRpc) { - SRetrieveAnalAlgoRsp rsp = {0}; - if (tDeserializeRetrieveAnalAlgoRsp(pRpc->pCont, pRpc->contLen, &rsp) == 0) { +static void dmUpdateAnalyticFunc(SDnodeData *pData, void *pTrans, SRpcMsg *pRpc) { + SRetrieveAnalyticAlgoRsp rsp = {0}; + if (tDeserializeRetrieveAnalyticAlgoRsp(pRpc->pCont, pRpc->contLen, &rsp) == 0) { taosAnalyUpdate(rsp.ver, rsp.hash); rsp.hash = NULL; } - tFreeRetrieveAnalAlgoRsp(&rsp); + tFreeRetrieveAnalyticAlgoRsp(&rsp); rpcFreeCont(pRpc->pCont); } @@ -176,7 +176,7 @@ static void dmProcessRpcMsg(SDnode *pDnode, SRpcMsg *pRpc, SEpSet *pEpSet) { dmUpdateRpcIpWhite(&pDnode->data, pTrans->serverRpc, pRpc); return; case TDMT_MND_RETRIEVE_ANAL_ALGO_RSP: - dmUpdateAnalFunc(&pDnode->data, pTrans->serverRpc, pRpc); + dmUpdateAnalyticFunc(&pDnode->data, pTrans->serverRpc, pRpc); return; default: break; diff --git a/source/dnode/mnode/impl/src/mndAnode.c b/source/dnode/mnode/impl/src/mndAnode.c index 0777c2e247..163e697cc1 100644 --- a/source/dnode/mnode/impl/src/mndAnode.c +++ b/source/dnode/mnode/impl/src/mndAnode.c @@ -847,10 +847,10 @@ static int32_t mndProcessAnalAlgoReq(SRpcMsg *pReq) { SAnalyticsUrl url; int32_t nameLen; char name[TSDB_ANALYTIC_ALGO_KEY_LEN]; - SRetrieveAnalAlgoReq req = {0}; - SRetrieveAnalAlgoRsp rsp = {0}; + SRetrieveAnalyticsAlgoReq req = {0}; + SRetrieveAnalyticAlgoRsp rsp = {0}; - TAOS_CHECK_GOTO(tDeserializeRetrieveAnalAlgoReq(pReq->pCont, pReq->contLen, &req), NULL, _OVER); + TAOS_CHECK_GOTO(tDeserializeRetrieveAnalyticAlgoReq(pReq->pCont, pReq->contLen, &req), NULL, _OVER); rsp.ver = sdbGetTableVer(pSdb, SDB_ANODE); if (req.analVer != rsp.ver) { @@ -906,15 +906,15 @@ static int32_t mndProcessAnalAlgoReq(SRpcMsg *pReq) { } } - int32_t contLen = tSerializeRetrieveAnalAlgoRsp(NULL, 0, &rsp); + int32_t contLen = tSerializeRetrieveAnalyticAlgoRsp(NULL, 0, &rsp); void *pHead = rpcMallocCont(contLen); - (void)tSerializeRetrieveAnalAlgoRsp(pHead, contLen, &rsp); + (void)tSerializeRetrieveAnalyticAlgoRsp(pHead, contLen, &rsp); pReq->info.rspLen = contLen; pReq->info.rsp = pHead; _OVER: - tFreeRetrieveAnalAlgoRsp(&rsp); + tFreeRetrieveAnalyticAlgoRsp(&rsp); TAOS_RETURN(code); } diff --git a/source/libs/executor/src/forecastoperator.c b/source/libs/executor/src/forecastoperator.c index 25052af523..ad7f37cad9 100644 --- a/source/libs/executor/src/forecastoperator.c +++ b/source/libs/executor/src/forecastoperator.c @@ -145,6 +145,7 @@ static int32_t forecastCloseBuf(SForecastSupp* pSupp, const char* id) { if (!hasWncheck) { qDebug("%s forecast wncheck not found from %s, use default:%" PRId64, id, pSupp->algoOpt, wncheck); } + code = taosAnalyBufWriteOptInt(pBuf, "wncheck", wncheck); if (code != 0) return code; diff --git a/tools/tdgpt/taosanalytics/algo/fc/arima.py b/tools/tdgpt/taosanalytics/algo/fc/arima.py index 787cb757df..79d7136440 100644 --- a/tools/tdgpt/taosanalytics/algo/fc/arima.py +++ b/tools/tdgpt/taosanalytics/algo/fc/arima.py @@ -86,11 +86,11 @@ class _ArimaService(AbstractForecastService): if len(self.list) > 3000: raise ValueError("number of input data is too large") - if self.fc_rows <= 0: + if self.rows <= 0: raise ValueError("fc rows is not specified yet") - res, mse, model_info = self.__do_forecast_helper(self.fc_rows) - insert_ts_list(res, self.start_ts, self.time_step, self.fc_rows) + res, mse, model_info = self.__do_forecast_helper(self.rows) + insert_ts_list(res, self.start_ts, self.time_step, self.rows) return { "mse": mse, diff --git a/tools/tdgpt/taosanalytics/algo/fc/gpt.py b/tools/tdgpt/taosanalytics/algo/fc/gpt.py index 6a6e13edb2..65fa0240c1 100644 --- a/tools/tdgpt/taosanalytics/algo/fc/gpt.py +++ b/tools/tdgpt/taosanalytics/algo/fc/gpt.py @@ -10,14 +10,14 @@ from taosanalytics.service import AbstractForecastService class _GPTService(AbstractForecastService): - name = 'TDtsfm_1' + name = 'tdtsfm_1' desc = "internal gpt forecast model based on transformer" def __init__(self): super().__init__() self.table_name = None - self.service_host = 'http://127.0.0.1:5000/ds_predict' + self.service_host = 'http://127.0.0.1:5000/tdtsfm' self.headers = {'Content-Type': 'application/json'} self.std = None @@ -29,11 +29,11 @@ class _GPTService(AbstractForecastService): if self.list is None or len(self.list) < self.period: raise ValueError("number of input data is less than the periods") - if self.fc_rows <= 0: + if self.rows <= 0: raise ValueError("fc rows is not specified yet") # let's request the gpt service - data = {"input": self.list, 'next_len': self.fc_rows} + data = {"input": self.list, 'next_len': self.rows} try: response = requests.post(self.service_host, data=json.dumps(data), headers=self.headers) except Exception as e: @@ -53,7 +53,7 @@ class _GPTService(AbstractForecastService): "res": [pred_y] } - insert_ts_list(res["res"], self.start_ts, self.time_step, self.fc_rows) + insert_ts_list(res["res"], self.start_ts, self.time_step, self.rows) return res diff --git a/tools/tdgpt/taosanalytics/algo/fc/holtwinters.py b/tools/tdgpt/taosanalytics/algo/fc/holtwinters.py index d8225eaa5a..24aea44fdb 100644 --- a/tools/tdgpt/taosanalytics/algo/fc/holtwinters.py +++ b/tools/tdgpt/taosanalytics/algo/fc/holtwinters.py @@ -66,11 +66,11 @@ class _HoltWintersService(AbstractForecastService): if self.list is None or len(self.list) < self.period: raise ValueError("number of input data is less than the periods") - if self.fc_rows <= 0: + if self.rows <= 0: raise ValueError("fc rows is not specified yet") - res, mse = self.__do_forecast_helper(self.list, self.fc_rows) - insert_ts_list(res, self.start_ts, self.time_step, self.fc_rows) + res, mse = self.__do_forecast_helper(self.list, self.rows) + insert_ts_list(res, self.start_ts, self.time_step, self.rows) # add the conf range if required return { diff --git a/tools/tdgpt/taosanalytics/algo/fc/lstm.py b/tools/tdgpt/taosanalytics/algo/fc/lstm.py index 5edae7fc9f..72534ab6ab 100644 --- a/tools/tdgpt/taosanalytics/algo/fc/lstm.py +++ b/tools/tdgpt/taosanalytics/algo/fc/lstm.py @@ -46,7 +46,7 @@ class _LSTMService(AbstractForecastService): res = self.model.predict(self.list) - insert_ts_list(res, self.start_ts, self.time_step, self.fc_rows) + insert_ts_list(res, self.start_ts, self.time_step, self.rows) if self.return_conf: res1 = [res.tolist(), res.tolist(), res.tolist()], None diff --git a/tools/tdgpt/taosanalytics/algo/forecast.py b/tools/tdgpt/taosanalytics/algo/forecast.py index 4baa92fe15..5c681bf3b3 100644 --- a/tools/tdgpt/taosanalytics/algo/forecast.py +++ b/tools/tdgpt/taosanalytics/algo/forecast.py @@ -41,7 +41,7 @@ def do_forecast(input_list, ts_list, algo_name, params): def do_add_fc_params(params, json_obj): """ add params into parameters """ if "forecast_rows" in json_obj: - params["fc_rows"] = int(json_obj["forecast_rows"]) + params["rows"] = int(json_obj["forecast_rows"]) if "start" in json_obj: params["start_ts"] = int(json_obj["start"]) diff --git a/tools/tdgpt/taosanalytics/app.py b/tools/tdgpt/taosanalytics/app.py index cff5f74447..e2196062fb 100644 --- a/tools/tdgpt/taosanalytics/app.py +++ b/tools/tdgpt/taosanalytics/app.py @@ -147,7 +147,7 @@ def handle_forecast_req(): try: res1 = do_forecast(payload[data_index], payload[ts_index], algo, params) - res = {"option": options, "rows": params["fc_rows"]} + res = {"option": options, "rows": params["rows"]} res.update(res1) app_logger.log_inst.debug("forecast result: %s", res) diff --git a/tools/tdgpt/taosanalytics/service.py b/tools/tdgpt/taosanalytics/service.py index b79d21501a..9e960f3e58 100644 --- a/tools/tdgpt/taosanalytics/service.py +++ b/tools/tdgpt/taosanalytics/service.py @@ -77,14 +77,14 @@ class AbstractForecastService(AbstractAnalyticsService, ABC): self.period = 0 self.start_ts = 0 self.time_step = 0 - self.fc_rows = 0 + self.rows = 0 self.return_conf = 1 self.conf = 0.05 def set_params(self, params: dict) -> None: - if not {'start_ts', 'time_step', 'fc_rows'}.issubset(params.keys()): - raise ValueError('params are missing, start_ts, time_step, fc_rows are all required') + if not {'start_ts', 'time_step', 'rows'}.issubset(params.keys()): + raise ValueError('params are missing, start_ts, time_step, rows are all required') self.start_ts = int(params['start_ts']) @@ -93,9 +93,9 @@ class AbstractForecastService(AbstractAnalyticsService, ABC): if self.time_step <= 0: raise ValueError('time_step should be greater than 0') - self.fc_rows = int(params['fc_rows']) + self.rows = int(params['rows']) - if self.fc_rows <= 0: + if self.rows <= 0: raise ValueError('fc rows is not specified yet') self.period = int(params['period']) if 'period' in params else 0 @@ -113,5 +113,5 @@ class AbstractForecastService(AbstractAnalyticsService, ABC): def get_params(self): return { "period": self.period, "start": self.start_ts, "every": self.time_step, - "forecast_rows": self.fc_rows, "return_conf": self.return_conf, "conf": self.conf + "forecast_rows": self.rows, "return_conf": self.return_conf, "conf": self.conf } diff --git a/tools/tdgpt/taosanalytics/test/forecast_test.py b/tools/tdgpt/taosanalytics/test/forecast_test.py index 4b2368c6ba..9e417d9263 100644 --- a/tools/tdgpt/taosanalytics/test/forecast_test.py +++ b/tools/tdgpt/taosanalytics/test/forecast_test.py @@ -41,7 +41,7 @@ class ForecastTest(unittest.TestCase): s.set_input_list(data, ts) self.assertRaises(ValueError, s.execute) - s.set_params({"fc_rows": 10, "start_ts": 171000000, "time_step": 86400 * 30}) + s.set_params({"rows": 10, "start_ts": 171000000, "time_step": 86400 * 30}) r = s.execute() draw_fc_results(data, len(r["res"]) > 2, r["res"], len(r["res"][0]), "holtwinters") @@ -54,7 +54,7 @@ class ForecastTest(unittest.TestCase): s.set_input_list(data, ts) s.set_params( { - "fc_rows": 10, "trend": 'mul', "seasonal": 'mul', "start_ts": 171000000, + "rows": 10, "trend": 'mul', "seasonal": 'mul', "start_ts": 171000000, "time_step": 86400 * 30, "period": 12 } ) @@ -71,28 +71,28 @@ class ForecastTest(unittest.TestCase): self.assertRaises(ValueError, s.set_params, {"trend": "mul"}) - self.assertRaises(ValueError, s.set_params, {"trend": "mul", "fc_rows": 10}) + self.assertRaises(ValueError, s.set_params, {"trend": "mul", "rows": 10}) self.assertRaises(ValueError, s.set_params, {"trend": "multi"}) self.assertRaises(ValueError, s.set_params, {"seasonal": "additive"}) self.assertRaises(ValueError, s.set_params, { - "fc_rows": 10, "trend": 'multi', "seasonal": 'addi', "start_ts": 171000000, + "rows": 10, "trend": 'multi', "seasonal": 'addi', "start_ts": 171000000, "time_step": 86400 * 30, "period": 12} ) self.assertRaises(ValueError, s.set_params, - {"fc_rows": 10, "trend": 'mul', "seasonal": 'add', "time_step": 86400 * 30, "period": 12} + {"rows": 10, "trend": 'mul', "seasonal": 'add', "time_step": 86400 * 30, "period": 12} ) - s.set_params({"fc_rows": 10, "start_ts": 171000000, "time_step": 86400 * 30}) + s.set_params({"rows": 10, "start_ts": 171000000, "time_step": 86400 * 30}) - self.assertRaises(ValueError, s.set_params, {"fc_rows": 'abc', "start_ts": 171000000, "time_step": 86400 * 30}) + self.assertRaises(ValueError, s.set_params, {"rows": 'abc', "start_ts": 171000000, "time_step": 86400 * 30}) - self.assertRaises(ValueError, s.set_params, {"fc_rows": 10, "start_ts": "aaa", "time_step": "30"}) + self.assertRaises(ValueError, s.set_params, {"rows": 10, "start_ts": "aaa", "time_step": "30"}) - self.assertRaises(ValueError, s.set_params, {"fc_rows": 10, "start_ts": 171000000, "time_step": 0}) + self.assertRaises(ValueError, s.set_params, {"rows": 10, "start_ts": 171000000, "time_step": 0}) def test_arima(self): """arima algorithm check""" @@ -103,7 +103,7 @@ class ForecastTest(unittest.TestCase): self.assertRaises(ValueError, s.execute) s.set_params( - {"fc_rows": 10, "start_ts": 171000000, "time_step": 86400 * 30, "period": 12, + {"rows": 10, "start_ts": 171000000, "time_step": 86400 * 30, "period": 12, "start_p": 0, "max_p": 10, "start_q": 0, "max_q": 10} ) r = s.execute() @@ -120,7 +120,7 @@ class ForecastTest(unittest.TestCase): # s = loader.get_service("td_gpt_fc") # s.set_input_list(data, ts) # - # s.set_params({"host":'192.168.2.90:5000/ds_predict', 'fc_rows': 10, 'start_ts': 171000000, 'time_step': 86400*30}) + # s.set_params({"host":'192.168.2.90:5000/ds_predict', 'rows': 10, 'start_ts': 171000000, 'time_step': 86400*30}) # r = s.execute() # # rows = len(r["res"][0])