fix(gpt): fix the bug for algo name case sensitive (#30362)
This commit is contained in:
parent
374a95dcca
commit
82ffeb82a9
|
@ -1257,18 +1257,18 @@ int32_t tDeserializeRetrieveIpWhite(void* buf, int32_t bufLen, SRetrieveIpWhiteR
|
|||
typedef struct {
|
||||
int32_t dnodeId;
|
||||
int64_t analVer;
|
||||
} SRetrieveAnalAlgoReq;
|
||||
} SRetrieveAnalyticsAlgoReq;
|
||||
|
||||
typedef struct {
|
||||
int64_t ver;
|
||||
SHashObj* hash; // algoname:algotype -> SAnalUrl
|
||||
} SRetrieveAnalAlgoRsp;
|
||||
} SRetrieveAnalyticAlgoRsp;
|
||||
|
||||
int32_t tSerializeRetrieveAnalAlgoReq(void* buf, int32_t bufLen, SRetrieveAnalAlgoReq* pReq);
|
||||
int32_t tDeserializeRetrieveAnalAlgoReq(void* buf, int32_t bufLen, SRetrieveAnalAlgoReq* pReq);
|
||||
int32_t tSerializeRetrieveAnalAlgoRsp(void* buf, int32_t bufLen, SRetrieveAnalAlgoRsp* pRsp);
|
||||
int32_t tDeserializeRetrieveAnalAlgoRsp(void* buf, int32_t bufLen, SRetrieveAnalAlgoRsp* pRsp);
|
||||
void tFreeRetrieveAnalAlgoRsp(SRetrieveAnalAlgoRsp* pRsp);
|
||||
int32_t tSerializeRetrieveAnalyticAlgoReq(void* buf, int32_t bufLen, SRetrieveAnalyticsAlgoReq* pReq);
|
||||
int32_t tDeserializeRetrieveAnalyticAlgoReq(void* buf, int32_t bufLen, SRetrieveAnalyticsAlgoReq* pReq);
|
||||
int32_t tSerializeRetrieveAnalyticAlgoRsp(void* buf, int32_t bufLen, SRetrieveAnalyticAlgoRsp* pRsp);
|
||||
int32_t tDeserializeRetrieveAnalyticAlgoRsp(void* buf, int32_t bufLen, SRetrieveAnalyticAlgoRsp* pRsp);
|
||||
void tFreeRetrieveAnalyticAlgoRsp(SRetrieveAnalyticAlgoRsp* pRsp);
|
||||
|
||||
typedef struct {
|
||||
int8_t alterType;
|
||||
|
|
|
@ -2297,7 +2297,7 @@ _exit:
|
|||
return code;
|
||||
}
|
||||
|
||||
int32_t tSerializeRetrieveAnalAlgoReq(void *buf, int32_t bufLen, SRetrieveAnalAlgoReq *pReq) {
|
||||
int32_t tSerializeRetrieveAnalyticAlgoReq(void *buf, int32_t bufLen, SRetrieveAnalyticsAlgoReq *pReq) {
|
||||
SEncoder encoder = {0};
|
||||
int32_t code = 0;
|
||||
int32_t lino;
|
||||
|
@ -2319,7 +2319,7 @@ _exit:
|
|||
return tlen;
|
||||
}
|
||||
|
||||
int32_t tDeserializeRetrieveAnalAlgoReq(void *buf, int32_t bufLen, SRetrieveAnalAlgoReq *pReq) {
|
||||
int32_t tDeserializeRetrieveAnalyticAlgoReq(void *buf, int32_t bufLen, SRetrieveAnalyticsAlgoReq *pReq) {
|
||||
SDecoder decoder = {0};
|
||||
int32_t code = 0;
|
||||
int32_t lino;
|
||||
|
@ -2336,7 +2336,7 @@ _exit:
|
|||
return code;
|
||||
}
|
||||
|
||||
int32_t tSerializeRetrieveAnalAlgoRsp(void *buf, int32_t bufLen, SRetrieveAnalAlgoRsp *pRsp) {
|
||||
int32_t tSerializeRetrieveAnalyticAlgoRsp(void *buf, int32_t bufLen, SRetrieveAnalyticAlgoRsp *pRsp) {
|
||||
SEncoder encoder = {0};
|
||||
int32_t code = 0;
|
||||
int32_t lino;
|
||||
|
@ -2387,7 +2387,7 @@ _exit:
|
|||
return tlen;
|
||||
}
|
||||
|
||||
int32_t tDeserializeRetrieveAnalAlgoRsp(void *buf, int32_t bufLen, SRetrieveAnalAlgoRsp *pRsp) {
|
||||
int32_t tDeserializeRetrieveAnalyticAlgoRsp(void *buf, int32_t bufLen, SRetrieveAnalyticAlgoRsp *pRsp) {
|
||||
if (pRsp->hash == NULL) {
|
||||
pRsp->hash = taosHashInit(64, MurmurHash3_32, true, HASH_ENTRY_LOCK);
|
||||
if (pRsp->hash == NULL) {
|
||||
|
@ -2425,7 +2425,10 @@ int32_t tDeserializeRetrieveAnalAlgoRsp(void *buf, int32_t bufLen, SRetrieveAnal
|
|||
TAOS_CHECK_EXIT(tDecodeBinaryAlloc(&decoder, (void **)&url.url, NULL) < 0);
|
||||
}
|
||||
|
||||
TAOS_CHECK_EXIT(taosHashPut(pRsp->hash, name, nameLen, &url, sizeof(SAnalyticsUrl)));
|
||||
char dstName[TSDB_ANALYTIC_ALGO_NAME_LEN] = {0};
|
||||
strntolower(dstName, name, nameLen);
|
||||
|
||||
TAOS_CHECK_EXIT(taosHashPut(pRsp->hash, dstName, nameLen, &url, sizeof(SAnalyticsUrl)));
|
||||
}
|
||||
|
||||
tEndDecode(&decoder);
|
||||
|
@ -2435,7 +2438,7 @@ _exit:
|
|||
return code;
|
||||
}
|
||||
|
||||
void tFreeRetrieveAnalAlgoRsp(SRetrieveAnalAlgoRsp *pRsp) {
|
||||
void tFreeRetrieveAnalyticAlgoRsp(SRetrieveAnalyticAlgoRsp *pRsp) {
|
||||
void *pIter = taosHashIterate(pRsp->hash, NULL);
|
||||
while (pIter != NULL) {
|
||||
SAnalyticsUrl *pUrl = (SAnalyticsUrl *)pIter;
|
||||
|
|
|
@ -98,15 +98,15 @@ static void dmMayShouldUpdateAnalFunc(SDnodeMgmt *pMgmt, int64_t newVer) {
|
|||
if (oldVer == newVer) return;
|
||||
dDebug("analysis on dnode ver:%" PRId64 ", status ver:%" PRId64, oldVer, newVer);
|
||||
|
||||
SRetrieveAnalAlgoReq req = {.dnodeId = pMgmt->pData->dnodeId, .analVer = oldVer};
|
||||
int32_t contLen = tSerializeRetrieveAnalAlgoReq(NULL, 0, &req);
|
||||
SRetrieveAnalyticsAlgoReq req = {.dnodeId = pMgmt->pData->dnodeId, .analVer = oldVer};
|
||||
int32_t contLen = tSerializeRetrieveAnalyticAlgoReq(NULL, 0, &req);
|
||||
if (contLen < 0) {
|
||||
dError("failed to serialize analysis function ver request since %s", tstrerror(contLen));
|
||||
return;
|
||||
}
|
||||
|
||||
void *pHead = rpcMallocCont(contLen);
|
||||
contLen = tSerializeRetrieveAnalAlgoReq(pHead, contLen, &req);
|
||||
contLen = tSerializeRetrieveAnalyticAlgoReq(pHead, contLen, &req);
|
||||
if (contLen < 0) {
|
||||
rpcFreeCont(pHead);
|
||||
dError("failed to serialize analysis function ver request since %s", tstrerror(contLen));
|
||||
|
|
|
@ -116,13 +116,13 @@ static bool dmIsForbiddenIp(int8_t forbidden, char *user, uint32_t clientIp) {
|
|||
}
|
||||
}
|
||||
|
||||
static void dmUpdateAnalFunc(SDnodeData *pData, void *pTrans, SRpcMsg *pRpc) {
|
||||
SRetrieveAnalAlgoRsp rsp = {0};
|
||||
if (tDeserializeRetrieveAnalAlgoRsp(pRpc->pCont, pRpc->contLen, &rsp) == 0) {
|
||||
static void dmUpdateAnalyticFunc(SDnodeData *pData, void *pTrans, SRpcMsg *pRpc) {
|
||||
SRetrieveAnalyticAlgoRsp rsp = {0};
|
||||
if (tDeserializeRetrieveAnalyticAlgoRsp(pRpc->pCont, pRpc->contLen, &rsp) == 0) {
|
||||
taosAnalyUpdate(rsp.ver, rsp.hash);
|
||||
rsp.hash = NULL;
|
||||
}
|
||||
tFreeRetrieveAnalAlgoRsp(&rsp);
|
||||
tFreeRetrieveAnalyticAlgoRsp(&rsp);
|
||||
rpcFreeCont(pRpc->pCont);
|
||||
}
|
||||
|
||||
|
@ -176,7 +176,7 @@ static void dmProcessRpcMsg(SDnode *pDnode, SRpcMsg *pRpc, SEpSet *pEpSet) {
|
|||
dmUpdateRpcIpWhite(&pDnode->data, pTrans->serverRpc, pRpc);
|
||||
return;
|
||||
case TDMT_MND_RETRIEVE_ANAL_ALGO_RSP:
|
||||
dmUpdateAnalFunc(&pDnode->data, pTrans->serverRpc, pRpc);
|
||||
dmUpdateAnalyticFunc(&pDnode->data, pTrans->serverRpc, pRpc);
|
||||
return;
|
||||
default:
|
||||
break;
|
||||
|
|
|
@ -847,10 +847,10 @@ static int32_t mndProcessAnalAlgoReq(SRpcMsg *pReq) {
|
|||
SAnalyticsUrl url;
|
||||
int32_t nameLen;
|
||||
char name[TSDB_ANALYTIC_ALGO_KEY_LEN];
|
||||
SRetrieveAnalAlgoReq req = {0};
|
||||
SRetrieveAnalAlgoRsp rsp = {0};
|
||||
SRetrieveAnalyticsAlgoReq req = {0};
|
||||
SRetrieveAnalyticAlgoRsp rsp = {0};
|
||||
|
||||
TAOS_CHECK_GOTO(tDeserializeRetrieveAnalAlgoReq(pReq->pCont, pReq->contLen, &req), NULL, _OVER);
|
||||
TAOS_CHECK_GOTO(tDeserializeRetrieveAnalyticAlgoReq(pReq->pCont, pReq->contLen, &req), NULL, _OVER);
|
||||
|
||||
rsp.ver = sdbGetTableVer(pSdb, SDB_ANODE);
|
||||
if (req.analVer != rsp.ver) {
|
||||
|
@ -906,15 +906,15 @@ static int32_t mndProcessAnalAlgoReq(SRpcMsg *pReq) {
|
|||
}
|
||||
}
|
||||
|
||||
int32_t contLen = tSerializeRetrieveAnalAlgoRsp(NULL, 0, &rsp);
|
||||
int32_t contLen = tSerializeRetrieveAnalyticAlgoRsp(NULL, 0, &rsp);
|
||||
void *pHead = rpcMallocCont(contLen);
|
||||
(void)tSerializeRetrieveAnalAlgoRsp(pHead, contLen, &rsp);
|
||||
(void)tSerializeRetrieveAnalyticAlgoRsp(pHead, contLen, &rsp);
|
||||
|
||||
pReq->info.rspLen = contLen;
|
||||
pReq->info.rsp = pHead;
|
||||
|
||||
_OVER:
|
||||
tFreeRetrieveAnalAlgoRsp(&rsp);
|
||||
tFreeRetrieveAnalyticAlgoRsp(&rsp);
|
||||
TAOS_RETURN(code);
|
||||
}
|
||||
|
||||
|
|
|
@ -145,6 +145,7 @@ static int32_t forecastCloseBuf(SForecastSupp* pSupp, const char* id) {
|
|||
if (!hasWncheck) {
|
||||
qDebug("%s forecast wncheck not found from %s, use default:%" PRId64, id, pSupp->algoOpt, wncheck);
|
||||
}
|
||||
|
||||
code = taosAnalyBufWriteOptInt(pBuf, "wncheck", wncheck);
|
||||
if (code != 0) return code;
|
||||
|
||||
|
|
|
@ -86,11 +86,11 @@ class _ArimaService(AbstractForecastService):
|
|||
if len(self.list) > 3000:
|
||||
raise ValueError("number of input data is too large")
|
||||
|
||||
if self.fc_rows <= 0:
|
||||
if self.rows <= 0:
|
||||
raise ValueError("fc rows is not specified yet")
|
||||
|
||||
res, mse, model_info = self.__do_forecast_helper(self.fc_rows)
|
||||
insert_ts_list(res, self.start_ts, self.time_step, self.fc_rows)
|
||||
res, mse, model_info = self.__do_forecast_helper(self.rows)
|
||||
insert_ts_list(res, self.start_ts, self.time_step, self.rows)
|
||||
|
||||
return {
|
||||
"mse": mse,
|
||||
|
|
|
@ -10,14 +10,14 @@ from taosanalytics.service import AbstractForecastService
|
|||
|
||||
|
||||
class _GPTService(AbstractForecastService):
|
||||
name = 'TDtsfm_1'
|
||||
name = 'tdtsfm_1'
|
||||
desc = "internal gpt forecast model based on transformer"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.table_name = None
|
||||
self.service_host = 'http://127.0.0.1:5000/ds_predict'
|
||||
self.service_host = 'http://127.0.0.1:5000/tdtsfm'
|
||||
self.headers = {'Content-Type': 'application/json'}
|
||||
|
||||
self.std = None
|
||||
|
@ -29,11 +29,11 @@ class _GPTService(AbstractForecastService):
|
|||
if self.list is None or len(self.list) < self.period:
|
||||
raise ValueError("number of input data is less than the periods")
|
||||
|
||||
if self.fc_rows <= 0:
|
||||
if self.rows <= 0:
|
||||
raise ValueError("fc rows is not specified yet")
|
||||
|
||||
# let's request the gpt service
|
||||
data = {"input": self.list, 'next_len': self.fc_rows}
|
||||
data = {"input": self.list, 'next_len': self.rows}
|
||||
try:
|
||||
response = requests.post(self.service_host, data=json.dumps(data), headers=self.headers)
|
||||
except Exception as e:
|
||||
|
@ -53,7 +53,7 @@ class _GPTService(AbstractForecastService):
|
|||
"res": [pred_y]
|
||||
}
|
||||
|
||||
insert_ts_list(res["res"], self.start_ts, self.time_step, self.fc_rows)
|
||||
insert_ts_list(res["res"], self.start_ts, self.time_step, self.rows)
|
||||
return res
|
||||
|
||||
|
||||
|
|
|
@ -66,11 +66,11 @@ class _HoltWintersService(AbstractForecastService):
|
|||
if self.list is None or len(self.list) < self.period:
|
||||
raise ValueError("number of input data is less than the periods")
|
||||
|
||||
if self.fc_rows <= 0:
|
||||
if self.rows <= 0:
|
||||
raise ValueError("fc rows is not specified yet")
|
||||
|
||||
res, mse = self.__do_forecast_helper(self.list, self.fc_rows)
|
||||
insert_ts_list(res, self.start_ts, self.time_step, self.fc_rows)
|
||||
res, mse = self.__do_forecast_helper(self.list, self.rows)
|
||||
insert_ts_list(res, self.start_ts, self.time_step, self.rows)
|
||||
|
||||
# add the conf range if required
|
||||
return {
|
||||
|
|
|
@ -46,7 +46,7 @@ class _LSTMService(AbstractForecastService):
|
|||
|
||||
res = self.model.predict(self.list)
|
||||
|
||||
insert_ts_list(res, self.start_ts, self.time_step, self.fc_rows)
|
||||
insert_ts_list(res, self.start_ts, self.time_step, self.rows)
|
||||
|
||||
if self.return_conf:
|
||||
res1 = [res.tolist(), res.tolist(), res.tolist()], None
|
||||
|
|
|
@ -41,7 +41,7 @@ def do_forecast(input_list, ts_list, algo_name, params):
|
|||
def do_add_fc_params(params, json_obj):
|
||||
""" add params into parameters """
|
||||
if "forecast_rows" in json_obj:
|
||||
params["fc_rows"] = int(json_obj["forecast_rows"])
|
||||
params["rows"] = int(json_obj["forecast_rows"])
|
||||
|
||||
if "start" in json_obj:
|
||||
params["start_ts"] = int(json_obj["start"])
|
||||
|
|
|
@ -147,7 +147,7 @@ def handle_forecast_req():
|
|||
|
||||
try:
|
||||
res1 = do_forecast(payload[data_index], payload[ts_index], algo, params)
|
||||
res = {"option": options, "rows": params["fc_rows"]}
|
||||
res = {"option": options, "rows": params["rows"]}
|
||||
res.update(res1)
|
||||
|
||||
app_logger.log_inst.debug("forecast result: %s", res)
|
||||
|
|
|
@ -77,14 +77,14 @@ class AbstractForecastService(AbstractAnalyticsService, ABC):
|
|||
self.period = 0
|
||||
self.start_ts = 0
|
||||
self.time_step = 0
|
||||
self.fc_rows = 0
|
||||
self.rows = 0
|
||||
|
||||
self.return_conf = 1
|
||||
self.conf = 0.05
|
||||
|
||||
def set_params(self, params: dict) -> None:
|
||||
if not {'start_ts', 'time_step', 'fc_rows'}.issubset(params.keys()):
|
||||
raise ValueError('params are missing, start_ts, time_step, fc_rows are all required')
|
||||
if not {'start_ts', 'time_step', 'rows'}.issubset(params.keys()):
|
||||
raise ValueError('params are missing, start_ts, time_step, rows are all required')
|
||||
|
||||
self.start_ts = int(params['start_ts'])
|
||||
|
||||
|
@ -93,9 +93,9 @@ class AbstractForecastService(AbstractAnalyticsService, ABC):
|
|||
if self.time_step <= 0:
|
||||
raise ValueError('time_step should be greater than 0')
|
||||
|
||||
self.fc_rows = int(params['fc_rows'])
|
||||
self.rows = int(params['rows'])
|
||||
|
||||
if self.fc_rows <= 0:
|
||||
if self.rows <= 0:
|
||||
raise ValueError('fc rows is not specified yet')
|
||||
|
||||
self.period = int(params['period']) if 'period' in params else 0
|
||||
|
@ -113,5 +113,5 @@ class AbstractForecastService(AbstractAnalyticsService, ABC):
|
|||
def get_params(self):
|
||||
return {
|
||||
"period": self.period, "start": self.start_ts, "every": self.time_step,
|
||||
"forecast_rows": self.fc_rows, "return_conf": self.return_conf, "conf": self.conf
|
||||
"forecast_rows": self.rows, "return_conf": self.return_conf, "conf": self.conf
|
||||
}
|
||||
|
|
|
@ -41,7 +41,7 @@ class ForecastTest(unittest.TestCase):
|
|||
s.set_input_list(data, ts)
|
||||
self.assertRaises(ValueError, s.execute)
|
||||
|
||||
s.set_params({"fc_rows": 10, "start_ts": 171000000, "time_step": 86400 * 30})
|
||||
s.set_params({"rows": 10, "start_ts": 171000000, "time_step": 86400 * 30})
|
||||
|
||||
r = s.execute()
|
||||
draw_fc_results(data, len(r["res"]) > 2, r["res"], len(r["res"][0]), "holtwinters")
|
||||
|
@ -54,7 +54,7 @@ class ForecastTest(unittest.TestCase):
|
|||
s.set_input_list(data, ts)
|
||||
s.set_params(
|
||||
{
|
||||
"fc_rows": 10, "trend": 'mul', "seasonal": 'mul', "start_ts": 171000000,
|
||||
"rows": 10, "trend": 'mul', "seasonal": 'mul', "start_ts": 171000000,
|
||||
"time_step": 86400 * 30, "period": 12
|
||||
}
|
||||
)
|
||||
|
@ -71,28 +71,28 @@ class ForecastTest(unittest.TestCase):
|
|||
|
||||
self.assertRaises(ValueError, s.set_params, {"trend": "mul"})
|
||||
|
||||
self.assertRaises(ValueError, s.set_params, {"trend": "mul", "fc_rows": 10})
|
||||
self.assertRaises(ValueError, s.set_params, {"trend": "mul", "rows": 10})
|
||||
|
||||
self.assertRaises(ValueError, s.set_params, {"trend": "multi"})
|
||||
|
||||
self.assertRaises(ValueError, s.set_params, {"seasonal": "additive"})
|
||||
|
||||
self.assertRaises(ValueError, s.set_params, {
|
||||
"fc_rows": 10, "trend": 'multi', "seasonal": 'addi', "start_ts": 171000000,
|
||||
"rows": 10, "trend": 'multi', "seasonal": 'addi', "start_ts": 171000000,
|
||||
"time_step": 86400 * 30, "period": 12}
|
||||
)
|
||||
|
||||
self.assertRaises(ValueError, s.set_params,
|
||||
{"fc_rows": 10, "trend": 'mul', "seasonal": 'add', "time_step": 86400 * 30, "period": 12}
|
||||
{"rows": 10, "trend": 'mul', "seasonal": 'add', "time_step": 86400 * 30, "period": 12}
|
||||
)
|
||||
|
||||
s.set_params({"fc_rows": 10, "start_ts": 171000000, "time_step": 86400 * 30})
|
||||
s.set_params({"rows": 10, "start_ts": 171000000, "time_step": 86400 * 30})
|
||||
|
||||
self.assertRaises(ValueError, s.set_params, {"fc_rows": 'abc', "start_ts": 171000000, "time_step": 86400 * 30})
|
||||
self.assertRaises(ValueError, s.set_params, {"rows": 'abc', "start_ts": 171000000, "time_step": 86400 * 30})
|
||||
|
||||
self.assertRaises(ValueError, s.set_params, {"fc_rows": 10, "start_ts": "aaa", "time_step": "30"})
|
||||
self.assertRaises(ValueError, s.set_params, {"rows": 10, "start_ts": "aaa", "time_step": "30"})
|
||||
|
||||
self.assertRaises(ValueError, s.set_params, {"fc_rows": 10, "start_ts": 171000000, "time_step": 0})
|
||||
self.assertRaises(ValueError, s.set_params, {"rows": 10, "start_ts": 171000000, "time_step": 0})
|
||||
|
||||
def test_arima(self):
|
||||
"""arima algorithm check"""
|
||||
|
@ -103,7 +103,7 @@ class ForecastTest(unittest.TestCase):
|
|||
self.assertRaises(ValueError, s.execute)
|
||||
|
||||
s.set_params(
|
||||
{"fc_rows": 10, "start_ts": 171000000, "time_step": 86400 * 30, "period": 12,
|
||||
{"rows": 10, "start_ts": 171000000, "time_step": 86400 * 30, "period": 12,
|
||||
"start_p": 0, "max_p": 10, "start_q": 0, "max_q": 10}
|
||||
)
|
||||
r = s.execute()
|
||||
|
@ -120,7 +120,7 @@ class ForecastTest(unittest.TestCase):
|
|||
# s = loader.get_service("td_gpt_fc")
|
||||
# s.set_input_list(data, ts)
|
||||
#
|
||||
# s.set_params({"host":'192.168.2.90:5000/ds_predict', 'fc_rows': 10, 'start_ts': 171000000, 'time_step': 86400*30})
|
||||
# s.set_params({"host":'192.168.2.90:5000/ds_predict', 'rows': 10, 'start_ts': 171000000, 'time_step': 86400*30})
|
||||
# r = s.execute()
|
||||
#
|
||||
# rows = len(r["res"][0])
|
||||
|
|
Loading…
Reference in New Issue