fix(gpt): fix case-sensitive handling of algorithm names (#30362)

Haojun Liao 2025-03-23 13:20:24 +08:00 committed by GitHub
parent 374a95dcca
commit 82ffeb82a9
14 changed files with 62 additions and 58 deletions

View File

@ -1257,18 +1257,18 @@ int32_t tDeserializeRetrieveIpWhite(void* buf, int32_t bufLen, SRetrieveIpWhiteR
typedef struct {
int32_t dnodeId;
int64_t analVer;
-} SRetrieveAnalAlgoReq;
+} SRetrieveAnalyticsAlgoReq;
typedef struct {
int64_t ver;
SHashObj* hash; // algoname:algotype -> SAnalUrl
-} SRetrieveAnalAlgoRsp;
+} SRetrieveAnalyticAlgoRsp;
-int32_t tSerializeRetrieveAnalAlgoReq(void* buf, int32_t bufLen, SRetrieveAnalAlgoReq* pReq);
-int32_t tDeserializeRetrieveAnalAlgoReq(void* buf, int32_t bufLen, SRetrieveAnalAlgoReq* pReq);
-int32_t tSerializeRetrieveAnalAlgoRsp(void* buf, int32_t bufLen, SRetrieveAnalAlgoRsp* pRsp);
-int32_t tDeserializeRetrieveAnalAlgoRsp(void* buf, int32_t bufLen, SRetrieveAnalAlgoRsp* pRsp);
-void tFreeRetrieveAnalAlgoRsp(SRetrieveAnalAlgoRsp* pRsp);
+int32_t tSerializeRetrieveAnalyticAlgoReq(void* buf, int32_t bufLen, SRetrieveAnalyticsAlgoReq* pReq);
+int32_t tDeserializeRetrieveAnalyticAlgoReq(void* buf, int32_t bufLen, SRetrieveAnalyticsAlgoReq* pReq);
+int32_t tSerializeRetrieveAnalyticAlgoRsp(void* buf, int32_t bufLen, SRetrieveAnalyticAlgoRsp* pRsp);
+int32_t tDeserializeRetrieveAnalyticAlgoRsp(void* buf, int32_t bufLen, SRetrieveAnalyticAlgoRsp* pRsp);
+void tFreeRetrieveAnalyticAlgoRsp(SRetrieveAnalyticAlgoRsp* pRsp);
typedef struct {
int8_t alterType;

View File

@ -2297,7 +2297,7 @@ _exit:
return code;
}
-int32_t tSerializeRetrieveAnalAlgoReq(void *buf, int32_t bufLen, SRetrieveAnalAlgoReq *pReq) {
+int32_t tSerializeRetrieveAnalyticAlgoReq(void *buf, int32_t bufLen, SRetrieveAnalyticsAlgoReq *pReq) {
SEncoder encoder = {0};
int32_t code = 0;
int32_t lino;
@ -2319,7 +2319,7 @@ _exit:
return tlen;
}
-int32_t tDeserializeRetrieveAnalAlgoReq(void *buf, int32_t bufLen, SRetrieveAnalAlgoReq *pReq) {
+int32_t tDeserializeRetrieveAnalyticAlgoReq(void *buf, int32_t bufLen, SRetrieveAnalyticsAlgoReq *pReq) {
SDecoder decoder = {0};
int32_t code = 0;
int32_t lino;
@ -2336,7 +2336,7 @@ _exit:
return code;
}
-int32_t tSerializeRetrieveAnalAlgoRsp(void *buf, int32_t bufLen, SRetrieveAnalAlgoRsp *pRsp) {
+int32_t tSerializeRetrieveAnalyticAlgoRsp(void *buf, int32_t bufLen, SRetrieveAnalyticAlgoRsp *pRsp) {
SEncoder encoder = {0};
int32_t code = 0;
int32_t lino;
@ -2387,7 +2387,7 @@ _exit:
return tlen;
}
-int32_t tDeserializeRetrieveAnalAlgoRsp(void *buf, int32_t bufLen, SRetrieveAnalAlgoRsp *pRsp) {
+int32_t tDeserializeRetrieveAnalyticAlgoRsp(void *buf, int32_t bufLen, SRetrieveAnalyticAlgoRsp *pRsp) {
if (pRsp->hash == NULL) {
pRsp->hash = taosHashInit(64, MurmurHash3_32, true, HASH_ENTRY_LOCK);
if (pRsp->hash == NULL) {
@ -2425,7 +2425,10 @@ int32_t tDeserializeRetrieveAnalAlgoRsp(void *buf, int32_t bufLen, SRetrieveAnal
TAOS_CHECK_EXIT(tDecodeBinaryAlloc(&decoder, (void **)&url.url, NULL) < 0);
}
-TAOS_CHECK_EXIT(taosHashPut(pRsp->hash, name, nameLen, &url, sizeof(SAnalyticsUrl)));
+char dstName[TSDB_ANALYTIC_ALGO_NAME_LEN] = {0};
+strntolower(dstName, name, nameLen);
+TAOS_CHECK_EXIT(taosHashPut(pRsp->hash, dstName, nameLen, &url, sizeof(SAnalyticsUrl)));
}
tEndDecode(&decoder);
@ -2435,7 +2438,7 @@ _exit:
return code;
}
-void tFreeRetrieveAnalAlgoRsp(SRetrieveAnalAlgoRsp *pRsp) {
+void tFreeRetrieveAnalyticAlgoRsp(SRetrieveAnalyticAlgoRsp *pRsp) {
void *pIter = taosHashIterate(pRsp->hash, NULL);
while (pIter != NULL) {
SAnalyticsUrl *pUrl = (SAnalyticsUrl *)pIter;

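The behavioral fix is in tDeserializeRetrieveAnalyticAlgoRsp above: the received algorithm name is copied through strntolower() into dstName before it is used as the hash key, so keys are always stored in lower case. A minimal Python sketch of the same idea (illustrative only; the real code is C, the hash key is actually algoname:algotype, and the name/URL below are placeholders):

```python
# Illustrative sketch: store algorithm entries under a lowercased key so that
# lookups are case-insensitive, mirroring the strntolower() call above.
algos = {}

def register_algo(name, url):
    algos[name.lower()] = url        # normalize the key on insert

def find_algo(name):
    return algos.get(name.lower())   # normalize the key on lookup

register_algo("TDtsfm_1", "http://127.0.0.1:5000/tdtsfm")   # placeholder entry
assert find_algo("tdtsfm_1") == find_algo("TDTSFM_1") == "http://127.0.0.1:5000/tdtsfm"
```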
View File

@ -98,15 +98,15 @@ static void dmMayShouldUpdateAnalFunc(SDnodeMgmt *pMgmt, int64_t newVer) {
if (oldVer == newVer) return;
dDebug("analysis on dnode ver:%" PRId64 ", status ver:%" PRId64, oldVer, newVer);
-SRetrieveAnalAlgoReq req = {.dnodeId = pMgmt->pData->dnodeId, .analVer = oldVer};
-int32_t contLen = tSerializeRetrieveAnalAlgoReq(NULL, 0, &req);
+SRetrieveAnalyticsAlgoReq req = {.dnodeId = pMgmt->pData->dnodeId, .analVer = oldVer};
+int32_t contLen = tSerializeRetrieveAnalyticAlgoReq(NULL, 0, &req);
if (contLen < 0) {
dError("failed to serialize analysis function ver request since %s", tstrerror(contLen));
return;
}
void *pHead = rpcMallocCont(contLen);
-contLen = tSerializeRetrieveAnalAlgoReq(pHead, contLen, &req);
+contLen = tSerializeRetrieveAnalyticAlgoReq(pHead, contLen, &req);
if (contLen < 0) {
rpcFreeCont(pHead);
dError("failed to serialize analysis function ver request since %s", tstrerror(contLen));

View File

@ -116,13 +116,13 @@ static bool dmIsForbiddenIp(int8_t forbidden, char *user, uint32_t clientIp) {
}
}
-static void dmUpdateAnalFunc(SDnodeData *pData, void *pTrans, SRpcMsg *pRpc) {
-SRetrieveAnalAlgoRsp rsp = {0};
-if (tDeserializeRetrieveAnalAlgoRsp(pRpc->pCont, pRpc->contLen, &rsp) == 0) {
+static void dmUpdateAnalyticFunc(SDnodeData *pData, void *pTrans, SRpcMsg *pRpc) {
+SRetrieveAnalyticAlgoRsp rsp = {0};
+if (tDeserializeRetrieveAnalyticAlgoRsp(pRpc->pCont, pRpc->contLen, &rsp) == 0) {
taosAnalyUpdate(rsp.ver, rsp.hash);
rsp.hash = NULL;
}
-tFreeRetrieveAnalAlgoRsp(&rsp);
+tFreeRetrieveAnalyticAlgoRsp(&rsp);
rpcFreeCont(pRpc->pCont);
}
@ -176,7 +176,7 @@ static void dmProcessRpcMsg(SDnode *pDnode, SRpcMsg *pRpc, SEpSet *pEpSet) {
dmUpdateRpcIpWhite(&pDnode->data, pTrans->serverRpc, pRpc);
return;
case TDMT_MND_RETRIEVE_ANAL_ALGO_RSP:
-dmUpdateAnalFunc(&pDnode->data, pTrans->serverRpc, pRpc);
+dmUpdateAnalyticFunc(&pDnode->data, pTrans->serverRpc, pRpc);
return;
default:
break;

View File

@ -847,10 +847,10 @@ static int32_t mndProcessAnalAlgoReq(SRpcMsg *pReq) {
SAnalyticsUrl url;
int32_t nameLen;
char name[TSDB_ANALYTIC_ALGO_KEY_LEN];
-SRetrieveAnalAlgoReq req = {0};
-SRetrieveAnalAlgoRsp rsp = {0};
+SRetrieveAnalyticsAlgoReq req = {0};
+SRetrieveAnalyticAlgoRsp rsp = {0};
-TAOS_CHECK_GOTO(tDeserializeRetrieveAnalAlgoReq(pReq->pCont, pReq->contLen, &req), NULL, _OVER);
+TAOS_CHECK_GOTO(tDeserializeRetrieveAnalyticAlgoReq(pReq->pCont, pReq->contLen, &req), NULL, _OVER);
rsp.ver = sdbGetTableVer(pSdb, SDB_ANODE);
if (req.analVer != rsp.ver) {
@ -906,15 +906,15 @@ static int32_t mndProcessAnalAlgoReq(SRpcMsg *pReq) {
}
}
-int32_t contLen = tSerializeRetrieveAnalAlgoRsp(NULL, 0, &rsp);
+int32_t contLen = tSerializeRetrieveAnalyticAlgoRsp(NULL, 0, &rsp);
void *pHead = rpcMallocCont(contLen);
-(void)tSerializeRetrieveAnalAlgoRsp(pHead, contLen, &rsp);
+(void)tSerializeRetrieveAnalyticAlgoRsp(pHead, contLen, &rsp);
pReq->info.rspLen = contLen;
pReq->info.rsp = pHead;
_OVER:
-tFreeRetrieveAnalAlgoRsp(&rsp);
+tFreeRetrieveAnalyticAlgoRsp(&rsp);
TAOS_RETURN(code);
}

View File

@ -145,6 +145,7 @@ static int32_t forecastCloseBuf(SForecastSupp* pSupp, const char* id) {
if (!hasWncheck) {
qDebug("%s forecast wncheck not found from %s, use default:%" PRId64, id, pSupp->algoOpt, wncheck);
}
+code = taosAnalyBufWriteOptInt(pBuf, "wncheck", wncheck);
if (code != 0) return code;

View File

@ -86,11 +86,11 @@ class _ArimaService(AbstractForecastService):
if len(self.list) > 3000:
raise ValueError("number of input data is too large")
-if self.fc_rows <= 0:
+if self.rows <= 0:
raise ValueError("fc rows is not specified yet")
-res, mse, model_info = self.__do_forecast_helper(self.fc_rows)
-insert_ts_list(res, self.start_ts, self.time_step, self.fc_rows)
+res, mse, model_info = self.__do_forecast_helper(self.rows)
+insert_ts_list(res, self.start_ts, self.time_step, self.rows)
return {
"mse": mse,

View File

@ -10,14 +10,14 @@ from taosanalytics.service import AbstractForecastService
class _GPTService(AbstractForecastService):
-name = 'TDtsfm_1'
+name = 'tdtsfm_1'
desc = "internal gpt forecast model based on transformer"
def __init__(self):
super().__init__()
self.table_name = None
-self.service_host = 'http://127.0.0.1:5000/ds_predict'
+self.service_host = 'http://127.0.0.1:5000/tdtsfm'
self.headers = {'Content-Type': 'application/json'}
self.std = None
@ -29,11 +29,11 @@ class _GPTService(AbstractForecastService):
if self.list is None or len(self.list) < self.period:
raise ValueError("number of input data is less than the periods")
-if self.fc_rows <= 0:
+if self.rows <= 0:
raise ValueError("fc rows is not specified yet")
# let's request the gpt service
data = {"input": self.list, 'next_len': self.fc_rows}
data = {"input": self.list, 'next_len': self.rows}
try:
response = requests.post(self.service_host, data=json.dumps(data), headers=self.headers)
except Exception as e:
@ -53,7 +53,7 @@ class _GPTService(AbstractForecastService):
"res": [pred_y]
}
insert_ts_list(res["res"], self.start_ts, self.time_step, self.fc_rows)
insert_ts_list(res["res"], self.start_ts, self.time_step, self.rows)
return res
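The hunks above rename the built-in model to the lower-case 'tdtsfm_1', point it at the /tdtsfm endpoint, and take the forecast length from self.rows instead of self.fc_rows. A minimal sketch of the POST the service issues, assuming a reachable endpoint (the input series and row count below are made-up values):

```python
# Sketch of the forecast request built by the GPT service above; values are placeholders.
import json
import requests

service_host = "http://127.0.0.1:5000/tdtsfm"            # new default endpoint
payload = {"input": [1.0, 2.0, 3.0], "next_len": 10}     # next_len now comes from self.rows
response = requests.post(service_host, data=json.dumps(payload),
                         headers={"Content-Type": "application/json"})
# the service then wraps the reply into its own result dict (see "res" above)
```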

View File

@ -66,11 +66,11 @@ class _HoltWintersService(AbstractForecastService):
if self.list is None or len(self.list) < self.period:
raise ValueError("number of input data is less than the periods")
-if self.fc_rows <= 0:
+if self.rows <= 0:
raise ValueError("fc rows is not specified yet")
-res, mse = self.__do_forecast_helper(self.list, self.fc_rows)
-insert_ts_list(res, self.start_ts, self.time_step, self.fc_rows)
+res, mse = self.__do_forecast_helper(self.list, self.rows)
+insert_ts_list(res, self.start_ts, self.time_step, self.rows)
# add the conf range if required
return {

View File

@ -46,7 +46,7 @@ class _LSTMService(AbstractForecastService):
res = self.model.predict(self.list)
-insert_ts_list(res, self.start_ts, self.time_step, self.fc_rows)
+insert_ts_list(res, self.start_ts, self.time_step, self.rows)
if self.return_conf:
res1 = [res.tolist(), res.tolist(), res.tolist()], None

View File

@ -41,7 +41,7 @@ def do_forecast(input_list, ts_list, algo_name, params):
def do_add_fc_params(params, json_obj):
""" add params into parameters """
if "forecast_rows" in json_obj:
params["fc_rows"] = int(json_obj["forecast_rows"])
params["rows"] = int(json_obj["forecast_rows"])
if "start" in json_obj:
params["start_ts"] = int(json_obj["start"])

View File

@ -147,7 +147,7 @@ def handle_forecast_req():
try:
res1 = do_forecast(payload[data_index], payload[ts_index], algo, params)
res = {"option": options, "rows": params["fc_rows"]}
res = {"option": options, "rows": params["rows"]}
res.update(res1)
app_logger.log_inst.debug("forecast result: %s", res)

View File

@ -77,14 +77,14 @@ class AbstractForecastService(AbstractAnalyticsService, ABC):
self.period = 0
self.start_ts = 0
self.time_step = 0
-self.fc_rows = 0
+self.rows = 0
self.return_conf = 1
self.conf = 0.05
def set_params(self, params: dict) -> None:
-if not {'start_ts', 'time_step', 'fc_rows'}.issubset(params.keys()):
-raise ValueError('params are missing, start_ts, time_step, fc_rows are all required')
+if not {'start_ts', 'time_step', 'rows'}.issubset(params.keys()):
+raise ValueError('params are missing, start_ts, time_step, rows are all required')
self.start_ts = int(params['start_ts'])
@ -93,9 +93,9 @@ class AbstractForecastService(AbstractAnalyticsService, ABC):
if self.time_step <= 0:
raise ValueError('time_step should be greater than 0')
-self.fc_rows = int(params['fc_rows'])
+self.rows = int(params['rows'])
-if self.fc_rows <= 0:
+if self.rows <= 0:
raise ValueError('fc rows is not specified yet')
self.period = int(params['period']) if 'period' in params else 0
@ -113,5 +113,5 @@ class AbstractForecastService(AbstractAnalyticsService, ABC):
def get_params(self):
return {
"period": self.period, "start": self.start_ts, "every": self.time_step,
"forecast_rows": self.fc_rows, "return_conf": self.return_conf, "conf": self.conf
"forecast_rows": self.rows, "return_conf": self.return_conf, "conf": self.conf
}
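Taken together, set_params() above now requires start_ts, time_step, and rows (rows replaces the old fc_rows key) and still rejects non-positive values. A hedged, standalone sketch of that contract (not the project code):

```python
# Standalone sketch of the renamed-parameter checks enforced by set_params() above.
def check_forecast_params(params):
    if not {"start_ts", "time_step", "rows"}.issubset(params):
        raise ValueError("params are missing, start_ts, time_step, rows are all required")
    if int(params["time_step"]) <= 0:
        raise ValueError("time_step should be greater than 0")
    if int(params["rows"]) <= 0:
        raise ValueError("fc rows is not specified yet")

check_forecast_params({"rows": 10, "start_ts": 171000000, "time_step": 86400 * 30})  # passes
```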

View File

@ -41,7 +41,7 @@ class ForecastTest(unittest.TestCase):
s.set_input_list(data, ts)
self.assertRaises(ValueError, s.execute)
s.set_params({"fc_rows": 10, "start_ts": 171000000, "time_step": 86400 * 30})
s.set_params({"rows": 10, "start_ts": 171000000, "time_step": 86400 * 30})
r = s.execute()
draw_fc_results(data, len(r["res"]) > 2, r["res"], len(r["res"][0]), "holtwinters")
@ -54,7 +54,7 @@ class ForecastTest(unittest.TestCase):
s.set_input_list(data, ts)
s.set_params(
{
"fc_rows": 10, "trend": 'mul', "seasonal": 'mul', "start_ts": 171000000,
"rows": 10, "trend": 'mul', "seasonal": 'mul', "start_ts": 171000000,
"time_step": 86400 * 30, "period": 12
}
)
@ -71,28 +71,28 @@ class ForecastTest(unittest.TestCase):
self.assertRaises(ValueError, s.set_params, {"trend": "mul"})
self.assertRaises(ValueError, s.set_params, {"trend": "mul", "fc_rows": 10})
self.assertRaises(ValueError, s.set_params, {"trend": "mul", "rows": 10})
self.assertRaises(ValueError, s.set_params, {"trend": "multi"})
self.assertRaises(ValueError, s.set_params, {"seasonal": "additive"})
self.assertRaises(ValueError, s.set_params, {
"fc_rows": 10, "trend": 'multi', "seasonal": 'addi', "start_ts": 171000000,
"rows": 10, "trend": 'multi', "seasonal": 'addi', "start_ts": 171000000,
"time_step": 86400 * 30, "period": 12}
)
self.assertRaises(ValueError, s.set_params,
{"fc_rows": 10, "trend": 'mul', "seasonal": 'add', "time_step": 86400 * 30, "period": 12}
{"rows": 10, "trend": 'mul', "seasonal": 'add', "time_step": 86400 * 30, "period": 12}
)
s.set_params({"fc_rows": 10, "start_ts": 171000000, "time_step": 86400 * 30})
s.set_params({"rows": 10, "start_ts": 171000000, "time_step": 86400 * 30})
self.assertRaises(ValueError, s.set_params, {"fc_rows": 'abc', "start_ts": 171000000, "time_step": 86400 * 30})
self.assertRaises(ValueError, s.set_params, {"rows": 'abc', "start_ts": 171000000, "time_step": 86400 * 30})
self.assertRaises(ValueError, s.set_params, {"fc_rows": 10, "start_ts": "aaa", "time_step": "30"})
self.assertRaises(ValueError, s.set_params, {"rows": 10, "start_ts": "aaa", "time_step": "30"})
self.assertRaises(ValueError, s.set_params, {"fc_rows": 10, "start_ts": 171000000, "time_step": 0})
self.assertRaises(ValueError, s.set_params, {"rows": 10, "start_ts": 171000000, "time_step": 0})
def test_arima(self):
"""arima algorithm check"""
@ -103,7 +103,7 @@ class ForecastTest(unittest.TestCase):
self.assertRaises(ValueError, s.execute)
s.set_params(
{"fc_rows": 10, "start_ts": 171000000, "time_step": 86400 * 30, "period": 12,
{"rows": 10, "start_ts": 171000000, "time_step": 86400 * 30, "period": 12,
"start_p": 0, "max_p": 10, "start_q": 0, "max_q": 10}
)
r = s.execute()
@ -120,7 +120,7 @@ class ForecastTest(unittest.TestCase):
# s = loader.get_service("td_gpt_fc")
# s.set_input_list(data, ts)
#
# s.set_params({"host":'192.168.2.90:5000/ds_predict', 'fc_rows': 10, 'start_ts': 171000000, 'time_step': 86400*30})
# s.set_params({"host":'192.168.2.90:5000/ds_predict', 'rows': 10, 'start_ts': 171000000, 'time_step': 86400*30})
# r = s.execute()
#
# rows = len(r["res"][0])