fix(gpt): fix the bug for algo name case sensitive (#30362)

This commit is contained in:
Haojun Liao 2025-03-23 13:20:24 +08:00 committed by GitHub
parent 374a95dcca
commit 82ffeb82a9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 62 additions and 58 deletions

View File

@ -1257,18 +1257,18 @@ int32_t tDeserializeRetrieveIpWhite(void* buf, int32_t bufLen, SRetrieveIpWhiteR
typedef struct { typedef struct {
int32_t dnodeId; int32_t dnodeId;
int64_t analVer; int64_t analVer;
} SRetrieveAnalAlgoReq; } SRetrieveAnalyticsAlgoReq;
typedef struct { typedef struct {
int64_t ver; int64_t ver;
SHashObj* hash; // algoname:algotype -> SAnalUrl SHashObj* hash; // algoname:algotype -> SAnalUrl
} SRetrieveAnalAlgoRsp; } SRetrieveAnalyticAlgoRsp;
int32_t tSerializeRetrieveAnalAlgoReq(void* buf, int32_t bufLen, SRetrieveAnalAlgoReq* pReq); int32_t tSerializeRetrieveAnalyticAlgoReq(void* buf, int32_t bufLen, SRetrieveAnalyticsAlgoReq* pReq);
int32_t tDeserializeRetrieveAnalAlgoReq(void* buf, int32_t bufLen, SRetrieveAnalAlgoReq* pReq); int32_t tDeserializeRetrieveAnalyticAlgoReq(void* buf, int32_t bufLen, SRetrieveAnalyticsAlgoReq* pReq);
int32_t tSerializeRetrieveAnalAlgoRsp(void* buf, int32_t bufLen, SRetrieveAnalAlgoRsp* pRsp); int32_t tSerializeRetrieveAnalyticAlgoRsp(void* buf, int32_t bufLen, SRetrieveAnalyticAlgoRsp* pRsp);
int32_t tDeserializeRetrieveAnalAlgoRsp(void* buf, int32_t bufLen, SRetrieveAnalAlgoRsp* pRsp); int32_t tDeserializeRetrieveAnalyticAlgoRsp(void* buf, int32_t bufLen, SRetrieveAnalyticAlgoRsp* pRsp);
void tFreeRetrieveAnalAlgoRsp(SRetrieveAnalAlgoRsp* pRsp); void tFreeRetrieveAnalyticAlgoRsp(SRetrieveAnalyticAlgoRsp* pRsp);
typedef struct { typedef struct {
int8_t alterType; int8_t alterType;

View File

@ -2297,7 +2297,7 @@ _exit:
return code; return code;
} }
int32_t tSerializeRetrieveAnalAlgoReq(void *buf, int32_t bufLen, SRetrieveAnalAlgoReq *pReq) { int32_t tSerializeRetrieveAnalyticAlgoReq(void *buf, int32_t bufLen, SRetrieveAnalyticsAlgoReq *pReq) {
SEncoder encoder = {0}; SEncoder encoder = {0};
int32_t code = 0; int32_t code = 0;
int32_t lino; int32_t lino;
@ -2319,7 +2319,7 @@ _exit:
return tlen; return tlen;
} }
int32_t tDeserializeRetrieveAnalAlgoReq(void *buf, int32_t bufLen, SRetrieveAnalAlgoReq *pReq) { int32_t tDeserializeRetrieveAnalyticAlgoReq(void *buf, int32_t bufLen, SRetrieveAnalyticsAlgoReq *pReq) {
SDecoder decoder = {0}; SDecoder decoder = {0};
int32_t code = 0; int32_t code = 0;
int32_t lino; int32_t lino;
@ -2336,7 +2336,7 @@ _exit:
return code; return code;
} }
int32_t tSerializeRetrieveAnalAlgoRsp(void *buf, int32_t bufLen, SRetrieveAnalAlgoRsp *pRsp) { int32_t tSerializeRetrieveAnalyticAlgoRsp(void *buf, int32_t bufLen, SRetrieveAnalyticAlgoRsp *pRsp) {
SEncoder encoder = {0}; SEncoder encoder = {0};
int32_t code = 0; int32_t code = 0;
int32_t lino; int32_t lino;
@ -2387,7 +2387,7 @@ _exit:
return tlen; return tlen;
} }
int32_t tDeserializeRetrieveAnalAlgoRsp(void *buf, int32_t bufLen, SRetrieveAnalAlgoRsp *pRsp) { int32_t tDeserializeRetrieveAnalyticAlgoRsp(void *buf, int32_t bufLen, SRetrieveAnalyticAlgoRsp *pRsp) {
if (pRsp->hash == NULL) { if (pRsp->hash == NULL) {
pRsp->hash = taosHashInit(64, MurmurHash3_32, true, HASH_ENTRY_LOCK); pRsp->hash = taosHashInit(64, MurmurHash3_32, true, HASH_ENTRY_LOCK);
if (pRsp->hash == NULL) { if (pRsp->hash == NULL) {
@ -2425,7 +2425,10 @@ int32_t tDeserializeRetrieveAnalAlgoRsp(void *buf, int32_t bufLen, SRetrieveAnal
TAOS_CHECK_EXIT(tDecodeBinaryAlloc(&decoder, (void **)&url.url, NULL) < 0); TAOS_CHECK_EXIT(tDecodeBinaryAlloc(&decoder, (void **)&url.url, NULL) < 0);
} }
TAOS_CHECK_EXIT(taosHashPut(pRsp->hash, name, nameLen, &url, sizeof(SAnalyticsUrl))); char dstName[TSDB_ANALYTIC_ALGO_NAME_LEN] = {0};
strntolower(dstName, name, nameLen);
TAOS_CHECK_EXIT(taosHashPut(pRsp->hash, dstName, nameLen, &url, sizeof(SAnalyticsUrl)));
} }
tEndDecode(&decoder); tEndDecode(&decoder);
@ -2435,7 +2438,7 @@ _exit:
return code; return code;
} }
void tFreeRetrieveAnalAlgoRsp(SRetrieveAnalAlgoRsp *pRsp) { void tFreeRetrieveAnalyticAlgoRsp(SRetrieveAnalyticAlgoRsp *pRsp) {
void *pIter = taosHashIterate(pRsp->hash, NULL); void *pIter = taosHashIterate(pRsp->hash, NULL);
while (pIter != NULL) { while (pIter != NULL) {
SAnalyticsUrl *pUrl = (SAnalyticsUrl *)pIter; SAnalyticsUrl *pUrl = (SAnalyticsUrl *)pIter;

View File

@ -98,15 +98,15 @@ static void dmMayShouldUpdateAnalFunc(SDnodeMgmt *pMgmt, int64_t newVer) {
if (oldVer == newVer) return; if (oldVer == newVer) return;
dDebug("analysis on dnode ver:%" PRId64 ", status ver:%" PRId64, oldVer, newVer); dDebug("analysis on dnode ver:%" PRId64 ", status ver:%" PRId64, oldVer, newVer);
SRetrieveAnalAlgoReq req = {.dnodeId = pMgmt->pData->dnodeId, .analVer = oldVer}; SRetrieveAnalyticsAlgoReq req = {.dnodeId = pMgmt->pData->dnodeId, .analVer = oldVer};
int32_t contLen = tSerializeRetrieveAnalAlgoReq(NULL, 0, &req); int32_t contLen = tSerializeRetrieveAnalyticAlgoReq(NULL, 0, &req);
if (contLen < 0) { if (contLen < 0) {
dError("failed to serialize analysis function ver request since %s", tstrerror(contLen)); dError("failed to serialize analysis function ver request since %s", tstrerror(contLen));
return; return;
} }
void *pHead = rpcMallocCont(contLen); void *pHead = rpcMallocCont(contLen);
contLen = tSerializeRetrieveAnalAlgoReq(pHead, contLen, &req); contLen = tSerializeRetrieveAnalyticAlgoReq(pHead, contLen, &req);
if (contLen < 0) { if (contLen < 0) {
rpcFreeCont(pHead); rpcFreeCont(pHead);
dError("failed to serialize analysis function ver request since %s", tstrerror(contLen)); dError("failed to serialize analysis function ver request since %s", tstrerror(contLen));

View File

@ -116,13 +116,13 @@ static bool dmIsForbiddenIp(int8_t forbidden, char *user, uint32_t clientIp) {
} }
} }
static void dmUpdateAnalFunc(SDnodeData *pData, void *pTrans, SRpcMsg *pRpc) { static void dmUpdateAnalyticFunc(SDnodeData *pData, void *pTrans, SRpcMsg *pRpc) {
SRetrieveAnalAlgoRsp rsp = {0}; SRetrieveAnalyticAlgoRsp rsp = {0};
if (tDeserializeRetrieveAnalAlgoRsp(pRpc->pCont, pRpc->contLen, &rsp) == 0) { if (tDeserializeRetrieveAnalyticAlgoRsp(pRpc->pCont, pRpc->contLen, &rsp) == 0) {
taosAnalyUpdate(rsp.ver, rsp.hash); taosAnalyUpdate(rsp.ver, rsp.hash);
rsp.hash = NULL; rsp.hash = NULL;
} }
tFreeRetrieveAnalAlgoRsp(&rsp); tFreeRetrieveAnalyticAlgoRsp(&rsp);
rpcFreeCont(pRpc->pCont); rpcFreeCont(pRpc->pCont);
} }
@ -176,7 +176,7 @@ static void dmProcessRpcMsg(SDnode *pDnode, SRpcMsg *pRpc, SEpSet *pEpSet) {
dmUpdateRpcIpWhite(&pDnode->data, pTrans->serverRpc, pRpc); dmUpdateRpcIpWhite(&pDnode->data, pTrans->serverRpc, pRpc);
return; return;
case TDMT_MND_RETRIEVE_ANAL_ALGO_RSP: case TDMT_MND_RETRIEVE_ANAL_ALGO_RSP:
dmUpdateAnalFunc(&pDnode->data, pTrans->serverRpc, pRpc); dmUpdateAnalyticFunc(&pDnode->data, pTrans->serverRpc, pRpc);
return; return;
default: default:
break; break;

View File

@ -847,10 +847,10 @@ static int32_t mndProcessAnalAlgoReq(SRpcMsg *pReq) {
SAnalyticsUrl url; SAnalyticsUrl url;
int32_t nameLen; int32_t nameLen;
char name[TSDB_ANALYTIC_ALGO_KEY_LEN]; char name[TSDB_ANALYTIC_ALGO_KEY_LEN];
SRetrieveAnalAlgoReq req = {0}; SRetrieveAnalyticsAlgoReq req = {0};
SRetrieveAnalAlgoRsp rsp = {0}; SRetrieveAnalyticAlgoRsp rsp = {0};
TAOS_CHECK_GOTO(tDeserializeRetrieveAnalAlgoReq(pReq->pCont, pReq->contLen, &req), NULL, _OVER); TAOS_CHECK_GOTO(tDeserializeRetrieveAnalyticAlgoReq(pReq->pCont, pReq->contLen, &req), NULL, _OVER);
rsp.ver = sdbGetTableVer(pSdb, SDB_ANODE); rsp.ver = sdbGetTableVer(pSdb, SDB_ANODE);
if (req.analVer != rsp.ver) { if (req.analVer != rsp.ver) {
@ -906,15 +906,15 @@ static int32_t mndProcessAnalAlgoReq(SRpcMsg *pReq) {
} }
} }
int32_t contLen = tSerializeRetrieveAnalAlgoRsp(NULL, 0, &rsp); int32_t contLen = tSerializeRetrieveAnalyticAlgoRsp(NULL, 0, &rsp);
void *pHead = rpcMallocCont(contLen); void *pHead = rpcMallocCont(contLen);
(void)tSerializeRetrieveAnalAlgoRsp(pHead, contLen, &rsp); (void)tSerializeRetrieveAnalyticAlgoRsp(pHead, contLen, &rsp);
pReq->info.rspLen = contLen; pReq->info.rspLen = contLen;
pReq->info.rsp = pHead; pReq->info.rsp = pHead;
_OVER: _OVER:
tFreeRetrieveAnalAlgoRsp(&rsp); tFreeRetrieveAnalyticAlgoRsp(&rsp);
TAOS_RETURN(code); TAOS_RETURN(code);
} }

View File

@ -145,6 +145,7 @@ static int32_t forecastCloseBuf(SForecastSupp* pSupp, const char* id) {
if (!hasWncheck) { if (!hasWncheck) {
qDebug("%s forecast wncheck not found from %s, use default:%" PRId64, id, pSupp->algoOpt, wncheck); qDebug("%s forecast wncheck not found from %s, use default:%" PRId64, id, pSupp->algoOpt, wncheck);
} }
code = taosAnalyBufWriteOptInt(pBuf, "wncheck", wncheck); code = taosAnalyBufWriteOptInt(pBuf, "wncheck", wncheck);
if (code != 0) return code; if (code != 0) return code;

View File

@ -86,11 +86,11 @@ class _ArimaService(AbstractForecastService):
if len(self.list) > 3000: if len(self.list) > 3000:
raise ValueError("number of input data is too large") raise ValueError("number of input data is too large")
if self.fc_rows <= 0: if self.rows <= 0:
raise ValueError("fc rows is not specified yet") raise ValueError("fc rows is not specified yet")
res, mse, model_info = self.__do_forecast_helper(self.fc_rows) res, mse, model_info = self.__do_forecast_helper(self.rows)
insert_ts_list(res, self.start_ts, self.time_step, self.fc_rows) insert_ts_list(res, self.start_ts, self.time_step, self.rows)
return { return {
"mse": mse, "mse": mse,

View File

@ -10,14 +10,14 @@ from taosanalytics.service import AbstractForecastService
class _GPTService(AbstractForecastService): class _GPTService(AbstractForecastService):
name = 'TDtsfm_1' name = 'tdtsfm_1'
desc = "internal gpt forecast model based on transformer" desc = "internal gpt forecast model based on transformer"
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.table_name = None self.table_name = None
self.service_host = 'http://127.0.0.1:5000/ds_predict' self.service_host = 'http://127.0.0.1:5000/tdtsfm'
self.headers = {'Content-Type': 'application/json'} self.headers = {'Content-Type': 'application/json'}
self.std = None self.std = None
@ -29,11 +29,11 @@ class _GPTService(AbstractForecastService):
if self.list is None or len(self.list) < self.period: if self.list is None or len(self.list) < self.period:
raise ValueError("number of input data is less than the periods") raise ValueError("number of input data is less than the periods")
if self.fc_rows <= 0: if self.rows <= 0:
raise ValueError("fc rows is not specified yet") raise ValueError("fc rows is not specified yet")
# let's request the gpt service # let's request the gpt service
data = {"input": self.list, 'next_len': self.fc_rows} data = {"input": self.list, 'next_len': self.rows}
try: try:
response = requests.post(self.service_host, data=json.dumps(data), headers=self.headers) response = requests.post(self.service_host, data=json.dumps(data), headers=self.headers)
except Exception as e: except Exception as e:
@ -53,7 +53,7 @@ class _GPTService(AbstractForecastService):
"res": [pred_y] "res": [pred_y]
} }
insert_ts_list(res["res"], self.start_ts, self.time_step, self.fc_rows) insert_ts_list(res["res"], self.start_ts, self.time_step, self.rows)
return res return res

View File

@ -66,11 +66,11 @@ class _HoltWintersService(AbstractForecastService):
if self.list is None or len(self.list) < self.period: if self.list is None or len(self.list) < self.period:
raise ValueError("number of input data is less than the periods") raise ValueError("number of input data is less than the periods")
if self.fc_rows <= 0: if self.rows <= 0:
raise ValueError("fc rows is not specified yet") raise ValueError("fc rows is not specified yet")
res, mse = self.__do_forecast_helper(self.list, self.fc_rows) res, mse = self.__do_forecast_helper(self.list, self.rows)
insert_ts_list(res, self.start_ts, self.time_step, self.fc_rows) insert_ts_list(res, self.start_ts, self.time_step, self.rows)
# add the conf range if required # add the conf range if required
return { return {

View File

@ -46,7 +46,7 @@ class _LSTMService(AbstractForecastService):
res = self.model.predict(self.list) res = self.model.predict(self.list)
insert_ts_list(res, self.start_ts, self.time_step, self.fc_rows) insert_ts_list(res, self.start_ts, self.time_step, self.rows)
if self.return_conf: if self.return_conf:
res1 = [res.tolist(), res.tolist(), res.tolist()], None res1 = [res.tolist(), res.tolist(), res.tolist()], None

View File

@ -41,7 +41,7 @@ def do_forecast(input_list, ts_list, algo_name, params):
def do_add_fc_params(params, json_obj): def do_add_fc_params(params, json_obj):
""" add params into parameters """ """ add params into parameters """
if "forecast_rows" in json_obj: if "forecast_rows" in json_obj:
params["fc_rows"] = int(json_obj["forecast_rows"]) params["rows"] = int(json_obj["forecast_rows"])
if "start" in json_obj: if "start" in json_obj:
params["start_ts"] = int(json_obj["start"]) params["start_ts"] = int(json_obj["start"])

View File

@ -147,7 +147,7 @@ def handle_forecast_req():
try: try:
res1 = do_forecast(payload[data_index], payload[ts_index], algo, params) res1 = do_forecast(payload[data_index], payload[ts_index], algo, params)
res = {"option": options, "rows": params["fc_rows"]} res = {"option": options, "rows": params["rows"]}
res.update(res1) res.update(res1)
app_logger.log_inst.debug("forecast result: %s", res) app_logger.log_inst.debug("forecast result: %s", res)

View File

@ -77,14 +77,14 @@ class AbstractForecastService(AbstractAnalyticsService, ABC):
self.period = 0 self.period = 0
self.start_ts = 0 self.start_ts = 0
self.time_step = 0 self.time_step = 0
self.fc_rows = 0 self.rows = 0
self.return_conf = 1 self.return_conf = 1
self.conf = 0.05 self.conf = 0.05
def set_params(self, params: dict) -> None: def set_params(self, params: dict) -> None:
if not {'start_ts', 'time_step', 'fc_rows'}.issubset(params.keys()): if not {'start_ts', 'time_step', 'rows'}.issubset(params.keys()):
raise ValueError('params are missing, start_ts, time_step, fc_rows are all required') raise ValueError('params are missing, start_ts, time_step, rows are all required')
self.start_ts = int(params['start_ts']) self.start_ts = int(params['start_ts'])
@ -93,9 +93,9 @@ class AbstractForecastService(AbstractAnalyticsService, ABC):
if self.time_step <= 0: if self.time_step <= 0:
raise ValueError('time_step should be greater than 0') raise ValueError('time_step should be greater than 0')
self.fc_rows = int(params['fc_rows']) self.rows = int(params['rows'])
if self.fc_rows <= 0: if self.rows <= 0:
raise ValueError('fc rows is not specified yet') raise ValueError('fc rows is not specified yet')
self.period = int(params['period']) if 'period' in params else 0 self.period = int(params['period']) if 'period' in params else 0
@ -113,5 +113,5 @@ class AbstractForecastService(AbstractAnalyticsService, ABC):
def get_params(self): def get_params(self):
return { return {
"period": self.period, "start": self.start_ts, "every": self.time_step, "period": self.period, "start": self.start_ts, "every": self.time_step,
"forecast_rows": self.fc_rows, "return_conf": self.return_conf, "conf": self.conf "forecast_rows": self.rows, "return_conf": self.return_conf, "conf": self.conf
} }

View File

@ -41,7 +41,7 @@ class ForecastTest(unittest.TestCase):
s.set_input_list(data, ts) s.set_input_list(data, ts)
self.assertRaises(ValueError, s.execute) self.assertRaises(ValueError, s.execute)
s.set_params({"fc_rows": 10, "start_ts": 171000000, "time_step": 86400 * 30}) s.set_params({"rows": 10, "start_ts": 171000000, "time_step": 86400 * 30})
r = s.execute() r = s.execute()
draw_fc_results(data, len(r["res"]) > 2, r["res"], len(r["res"][0]), "holtwinters") draw_fc_results(data, len(r["res"]) > 2, r["res"], len(r["res"][0]), "holtwinters")
@ -54,7 +54,7 @@ class ForecastTest(unittest.TestCase):
s.set_input_list(data, ts) s.set_input_list(data, ts)
s.set_params( s.set_params(
{ {
"fc_rows": 10, "trend": 'mul', "seasonal": 'mul', "start_ts": 171000000, "rows": 10, "trend": 'mul', "seasonal": 'mul', "start_ts": 171000000,
"time_step": 86400 * 30, "period": 12 "time_step": 86400 * 30, "period": 12
} }
) )
@ -71,28 +71,28 @@ class ForecastTest(unittest.TestCase):
self.assertRaises(ValueError, s.set_params, {"trend": "mul"}) self.assertRaises(ValueError, s.set_params, {"trend": "mul"})
self.assertRaises(ValueError, s.set_params, {"trend": "mul", "fc_rows": 10}) self.assertRaises(ValueError, s.set_params, {"trend": "mul", "rows": 10})
self.assertRaises(ValueError, s.set_params, {"trend": "multi"}) self.assertRaises(ValueError, s.set_params, {"trend": "multi"})
self.assertRaises(ValueError, s.set_params, {"seasonal": "additive"}) self.assertRaises(ValueError, s.set_params, {"seasonal": "additive"})
self.assertRaises(ValueError, s.set_params, { self.assertRaises(ValueError, s.set_params, {
"fc_rows": 10, "trend": 'multi', "seasonal": 'addi', "start_ts": 171000000, "rows": 10, "trend": 'multi', "seasonal": 'addi', "start_ts": 171000000,
"time_step": 86400 * 30, "period": 12} "time_step": 86400 * 30, "period": 12}
) )
self.assertRaises(ValueError, s.set_params, self.assertRaises(ValueError, s.set_params,
{"fc_rows": 10, "trend": 'mul', "seasonal": 'add', "time_step": 86400 * 30, "period": 12} {"rows": 10, "trend": 'mul', "seasonal": 'add', "time_step": 86400 * 30, "period": 12}
) )
s.set_params({"fc_rows": 10, "start_ts": 171000000, "time_step": 86400 * 30}) s.set_params({"rows": 10, "start_ts": 171000000, "time_step": 86400 * 30})
self.assertRaises(ValueError, s.set_params, {"fc_rows": 'abc', "start_ts": 171000000, "time_step": 86400 * 30}) self.assertRaises(ValueError, s.set_params, {"rows": 'abc', "start_ts": 171000000, "time_step": 86400 * 30})
self.assertRaises(ValueError, s.set_params, {"fc_rows": 10, "start_ts": "aaa", "time_step": "30"}) self.assertRaises(ValueError, s.set_params, {"rows": 10, "start_ts": "aaa", "time_step": "30"})
self.assertRaises(ValueError, s.set_params, {"fc_rows": 10, "start_ts": 171000000, "time_step": 0}) self.assertRaises(ValueError, s.set_params, {"rows": 10, "start_ts": 171000000, "time_step": 0})
def test_arima(self): def test_arima(self):
"""arima algorithm check""" """arima algorithm check"""
@ -103,7 +103,7 @@ class ForecastTest(unittest.TestCase):
self.assertRaises(ValueError, s.execute) self.assertRaises(ValueError, s.execute)
s.set_params( s.set_params(
{"fc_rows": 10, "start_ts": 171000000, "time_step": 86400 * 30, "period": 12, {"rows": 10, "start_ts": 171000000, "time_step": 86400 * 30, "period": 12,
"start_p": 0, "max_p": 10, "start_q": 0, "max_q": 10} "start_p": 0, "max_p": 10, "start_q": 0, "max_q": 10}
) )
r = s.execute() r = s.execute()
@ -120,7 +120,7 @@ class ForecastTest(unittest.TestCase):
# s = loader.get_service("td_gpt_fc") # s = loader.get_service("td_gpt_fc")
# s.set_input_list(data, ts) # s.set_input_list(data, ts)
# #
# s.set_params({"host":'192.168.2.90:5000/ds_predict', 'fc_rows': 10, 'start_ts': 171000000, 'time_step': 86400*30}) # s.set_params({"host":'192.168.2.90:5000/ds_predict', 'rows': 10, 'start_ts': 171000000, 'time_step': 86400*30})
# r = s.execute() # r = s.execute()
# #
# rows = len(r["res"][0]) # rows = len(r["res"][0])