fix(gpt): fix the bug for algo name case sensitive (#30362)
This commit is contained in:
parent
374a95dcca
commit
82ffeb82a9
|
@ -1257,18 +1257,18 @@ int32_t tDeserializeRetrieveIpWhite(void* buf, int32_t bufLen, SRetrieveIpWhiteR
|
||||||
typedef struct {
|
typedef struct {
|
||||||
int32_t dnodeId;
|
int32_t dnodeId;
|
||||||
int64_t analVer;
|
int64_t analVer;
|
||||||
} SRetrieveAnalAlgoReq;
|
} SRetrieveAnalyticsAlgoReq;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
int64_t ver;
|
int64_t ver;
|
||||||
SHashObj* hash; // algoname:algotype -> SAnalUrl
|
SHashObj* hash; // algoname:algotype -> SAnalUrl
|
||||||
} SRetrieveAnalAlgoRsp;
|
} SRetrieveAnalyticAlgoRsp;
|
||||||
|
|
||||||
int32_t tSerializeRetrieveAnalAlgoReq(void* buf, int32_t bufLen, SRetrieveAnalAlgoReq* pReq);
|
int32_t tSerializeRetrieveAnalyticAlgoReq(void* buf, int32_t bufLen, SRetrieveAnalyticsAlgoReq* pReq);
|
||||||
int32_t tDeserializeRetrieveAnalAlgoReq(void* buf, int32_t bufLen, SRetrieveAnalAlgoReq* pReq);
|
int32_t tDeserializeRetrieveAnalyticAlgoReq(void* buf, int32_t bufLen, SRetrieveAnalyticsAlgoReq* pReq);
|
||||||
int32_t tSerializeRetrieveAnalAlgoRsp(void* buf, int32_t bufLen, SRetrieveAnalAlgoRsp* pRsp);
|
int32_t tSerializeRetrieveAnalyticAlgoRsp(void* buf, int32_t bufLen, SRetrieveAnalyticAlgoRsp* pRsp);
|
||||||
int32_t tDeserializeRetrieveAnalAlgoRsp(void* buf, int32_t bufLen, SRetrieveAnalAlgoRsp* pRsp);
|
int32_t tDeserializeRetrieveAnalyticAlgoRsp(void* buf, int32_t bufLen, SRetrieveAnalyticAlgoRsp* pRsp);
|
||||||
void tFreeRetrieveAnalAlgoRsp(SRetrieveAnalAlgoRsp* pRsp);
|
void tFreeRetrieveAnalyticAlgoRsp(SRetrieveAnalyticAlgoRsp* pRsp);
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
int8_t alterType;
|
int8_t alterType;
|
||||||
|
|
|
@ -2297,7 +2297,7 @@ _exit:
|
||||||
return code;
|
return code;
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t tSerializeRetrieveAnalAlgoReq(void *buf, int32_t bufLen, SRetrieveAnalAlgoReq *pReq) {
|
int32_t tSerializeRetrieveAnalyticAlgoReq(void *buf, int32_t bufLen, SRetrieveAnalyticsAlgoReq *pReq) {
|
||||||
SEncoder encoder = {0};
|
SEncoder encoder = {0};
|
||||||
int32_t code = 0;
|
int32_t code = 0;
|
||||||
int32_t lino;
|
int32_t lino;
|
||||||
|
@ -2319,7 +2319,7 @@ _exit:
|
||||||
return tlen;
|
return tlen;
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t tDeserializeRetrieveAnalAlgoReq(void *buf, int32_t bufLen, SRetrieveAnalAlgoReq *pReq) {
|
int32_t tDeserializeRetrieveAnalyticAlgoReq(void *buf, int32_t bufLen, SRetrieveAnalyticsAlgoReq *pReq) {
|
||||||
SDecoder decoder = {0};
|
SDecoder decoder = {0};
|
||||||
int32_t code = 0;
|
int32_t code = 0;
|
||||||
int32_t lino;
|
int32_t lino;
|
||||||
|
@ -2336,7 +2336,7 @@ _exit:
|
||||||
return code;
|
return code;
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t tSerializeRetrieveAnalAlgoRsp(void *buf, int32_t bufLen, SRetrieveAnalAlgoRsp *pRsp) {
|
int32_t tSerializeRetrieveAnalyticAlgoRsp(void *buf, int32_t bufLen, SRetrieveAnalyticAlgoRsp *pRsp) {
|
||||||
SEncoder encoder = {0};
|
SEncoder encoder = {0};
|
||||||
int32_t code = 0;
|
int32_t code = 0;
|
||||||
int32_t lino;
|
int32_t lino;
|
||||||
|
@ -2387,7 +2387,7 @@ _exit:
|
||||||
return tlen;
|
return tlen;
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t tDeserializeRetrieveAnalAlgoRsp(void *buf, int32_t bufLen, SRetrieveAnalAlgoRsp *pRsp) {
|
int32_t tDeserializeRetrieveAnalyticAlgoRsp(void *buf, int32_t bufLen, SRetrieveAnalyticAlgoRsp *pRsp) {
|
||||||
if (pRsp->hash == NULL) {
|
if (pRsp->hash == NULL) {
|
||||||
pRsp->hash = taosHashInit(64, MurmurHash3_32, true, HASH_ENTRY_LOCK);
|
pRsp->hash = taosHashInit(64, MurmurHash3_32, true, HASH_ENTRY_LOCK);
|
||||||
if (pRsp->hash == NULL) {
|
if (pRsp->hash == NULL) {
|
||||||
|
@ -2425,7 +2425,10 @@ int32_t tDeserializeRetrieveAnalAlgoRsp(void *buf, int32_t bufLen, SRetrieveAnal
|
||||||
TAOS_CHECK_EXIT(tDecodeBinaryAlloc(&decoder, (void **)&url.url, NULL) < 0);
|
TAOS_CHECK_EXIT(tDecodeBinaryAlloc(&decoder, (void **)&url.url, NULL) < 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
TAOS_CHECK_EXIT(taosHashPut(pRsp->hash, name, nameLen, &url, sizeof(SAnalyticsUrl)));
|
char dstName[TSDB_ANALYTIC_ALGO_NAME_LEN] = {0};
|
||||||
|
strntolower(dstName, name, nameLen);
|
||||||
|
|
||||||
|
TAOS_CHECK_EXIT(taosHashPut(pRsp->hash, dstName, nameLen, &url, sizeof(SAnalyticsUrl)));
|
||||||
}
|
}
|
||||||
|
|
||||||
tEndDecode(&decoder);
|
tEndDecode(&decoder);
|
||||||
|
@ -2435,7 +2438,7 @@ _exit:
|
||||||
return code;
|
return code;
|
||||||
}
|
}
|
||||||
|
|
||||||
void tFreeRetrieveAnalAlgoRsp(SRetrieveAnalAlgoRsp *pRsp) {
|
void tFreeRetrieveAnalyticAlgoRsp(SRetrieveAnalyticAlgoRsp *pRsp) {
|
||||||
void *pIter = taosHashIterate(pRsp->hash, NULL);
|
void *pIter = taosHashIterate(pRsp->hash, NULL);
|
||||||
while (pIter != NULL) {
|
while (pIter != NULL) {
|
||||||
SAnalyticsUrl *pUrl = (SAnalyticsUrl *)pIter;
|
SAnalyticsUrl *pUrl = (SAnalyticsUrl *)pIter;
|
||||||
|
|
|
@ -98,15 +98,15 @@ static void dmMayShouldUpdateAnalFunc(SDnodeMgmt *pMgmt, int64_t newVer) {
|
||||||
if (oldVer == newVer) return;
|
if (oldVer == newVer) return;
|
||||||
dDebug("analysis on dnode ver:%" PRId64 ", status ver:%" PRId64, oldVer, newVer);
|
dDebug("analysis on dnode ver:%" PRId64 ", status ver:%" PRId64, oldVer, newVer);
|
||||||
|
|
||||||
SRetrieveAnalAlgoReq req = {.dnodeId = pMgmt->pData->dnodeId, .analVer = oldVer};
|
SRetrieveAnalyticsAlgoReq req = {.dnodeId = pMgmt->pData->dnodeId, .analVer = oldVer};
|
||||||
int32_t contLen = tSerializeRetrieveAnalAlgoReq(NULL, 0, &req);
|
int32_t contLen = tSerializeRetrieveAnalyticAlgoReq(NULL, 0, &req);
|
||||||
if (contLen < 0) {
|
if (contLen < 0) {
|
||||||
dError("failed to serialize analysis function ver request since %s", tstrerror(contLen));
|
dError("failed to serialize analysis function ver request since %s", tstrerror(contLen));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
void *pHead = rpcMallocCont(contLen);
|
void *pHead = rpcMallocCont(contLen);
|
||||||
contLen = tSerializeRetrieveAnalAlgoReq(pHead, contLen, &req);
|
contLen = tSerializeRetrieveAnalyticAlgoReq(pHead, contLen, &req);
|
||||||
if (contLen < 0) {
|
if (contLen < 0) {
|
||||||
rpcFreeCont(pHead);
|
rpcFreeCont(pHead);
|
||||||
dError("failed to serialize analysis function ver request since %s", tstrerror(contLen));
|
dError("failed to serialize analysis function ver request since %s", tstrerror(contLen));
|
||||||
|
|
|
@ -116,13 +116,13 @@ static bool dmIsForbiddenIp(int8_t forbidden, char *user, uint32_t clientIp) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dmUpdateAnalFunc(SDnodeData *pData, void *pTrans, SRpcMsg *pRpc) {
|
static void dmUpdateAnalyticFunc(SDnodeData *pData, void *pTrans, SRpcMsg *pRpc) {
|
||||||
SRetrieveAnalAlgoRsp rsp = {0};
|
SRetrieveAnalyticAlgoRsp rsp = {0};
|
||||||
if (tDeserializeRetrieveAnalAlgoRsp(pRpc->pCont, pRpc->contLen, &rsp) == 0) {
|
if (tDeserializeRetrieveAnalyticAlgoRsp(pRpc->pCont, pRpc->contLen, &rsp) == 0) {
|
||||||
taosAnalyUpdate(rsp.ver, rsp.hash);
|
taosAnalyUpdate(rsp.ver, rsp.hash);
|
||||||
rsp.hash = NULL;
|
rsp.hash = NULL;
|
||||||
}
|
}
|
||||||
tFreeRetrieveAnalAlgoRsp(&rsp);
|
tFreeRetrieveAnalyticAlgoRsp(&rsp);
|
||||||
rpcFreeCont(pRpc->pCont);
|
rpcFreeCont(pRpc->pCont);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -176,7 +176,7 @@ static void dmProcessRpcMsg(SDnode *pDnode, SRpcMsg *pRpc, SEpSet *pEpSet) {
|
||||||
dmUpdateRpcIpWhite(&pDnode->data, pTrans->serverRpc, pRpc);
|
dmUpdateRpcIpWhite(&pDnode->data, pTrans->serverRpc, pRpc);
|
||||||
return;
|
return;
|
||||||
case TDMT_MND_RETRIEVE_ANAL_ALGO_RSP:
|
case TDMT_MND_RETRIEVE_ANAL_ALGO_RSP:
|
||||||
dmUpdateAnalFunc(&pDnode->data, pTrans->serverRpc, pRpc);
|
dmUpdateAnalyticFunc(&pDnode->data, pTrans->serverRpc, pRpc);
|
||||||
return;
|
return;
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
|
|
|
@ -847,10 +847,10 @@ static int32_t mndProcessAnalAlgoReq(SRpcMsg *pReq) {
|
||||||
SAnalyticsUrl url;
|
SAnalyticsUrl url;
|
||||||
int32_t nameLen;
|
int32_t nameLen;
|
||||||
char name[TSDB_ANALYTIC_ALGO_KEY_LEN];
|
char name[TSDB_ANALYTIC_ALGO_KEY_LEN];
|
||||||
SRetrieveAnalAlgoReq req = {0};
|
SRetrieveAnalyticsAlgoReq req = {0};
|
||||||
SRetrieveAnalAlgoRsp rsp = {0};
|
SRetrieveAnalyticAlgoRsp rsp = {0};
|
||||||
|
|
||||||
TAOS_CHECK_GOTO(tDeserializeRetrieveAnalAlgoReq(pReq->pCont, pReq->contLen, &req), NULL, _OVER);
|
TAOS_CHECK_GOTO(tDeserializeRetrieveAnalyticAlgoReq(pReq->pCont, pReq->contLen, &req), NULL, _OVER);
|
||||||
|
|
||||||
rsp.ver = sdbGetTableVer(pSdb, SDB_ANODE);
|
rsp.ver = sdbGetTableVer(pSdb, SDB_ANODE);
|
||||||
if (req.analVer != rsp.ver) {
|
if (req.analVer != rsp.ver) {
|
||||||
|
@ -906,15 +906,15 @@ static int32_t mndProcessAnalAlgoReq(SRpcMsg *pReq) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t contLen = tSerializeRetrieveAnalAlgoRsp(NULL, 0, &rsp);
|
int32_t contLen = tSerializeRetrieveAnalyticAlgoRsp(NULL, 0, &rsp);
|
||||||
void *pHead = rpcMallocCont(contLen);
|
void *pHead = rpcMallocCont(contLen);
|
||||||
(void)tSerializeRetrieveAnalAlgoRsp(pHead, contLen, &rsp);
|
(void)tSerializeRetrieveAnalyticAlgoRsp(pHead, contLen, &rsp);
|
||||||
|
|
||||||
pReq->info.rspLen = contLen;
|
pReq->info.rspLen = contLen;
|
||||||
pReq->info.rsp = pHead;
|
pReq->info.rsp = pHead;
|
||||||
|
|
||||||
_OVER:
|
_OVER:
|
||||||
tFreeRetrieveAnalAlgoRsp(&rsp);
|
tFreeRetrieveAnalyticAlgoRsp(&rsp);
|
||||||
TAOS_RETURN(code);
|
TAOS_RETURN(code);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -145,6 +145,7 @@ static int32_t forecastCloseBuf(SForecastSupp* pSupp, const char* id) {
|
||||||
if (!hasWncheck) {
|
if (!hasWncheck) {
|
||||||
qDebug("%s forecast wncheck not found from %s, use default:%" PRId64, id, pSupp->algoOpt, wncheck);
|
qDebug("%s forecast wncheck not found from %s, use default:%" PRId64, id, pSupp->algoOpt, wncheck);
|
||||||
}
|
}
|
||||||
|
|
||||||
code = taosAnalyBufWriteOptInt(pBuf, "wncheck", wncheck);
|
code = taosAnalyBufWriteOptInt(pBuf, "wncheck", wncheck);
|
||||||
if (code != 0) return code;
|
if (code != 0) return code;
|
||||||
|
|
||||||
|
|
|
@ -86,11 +86,11 @@ class _ArimaService(AbstractForecastService):
|
||||||
if len(self.list) > 3000:
|
if len(self.list) > 3000:
|
||||||
raise ValueError("number of input data is too large")
|
raise ValueError("number of input data is too large")
|
||||||
|
|
||||||
if self.fc_rows <= 0:
|
if self.rows <= 0:
|
||||||
raise ValueError("fc rows is not specified yet")
|
raise ValueError("fc rows is not specified yet")
|
||||||
|
|
||||||
res, mse, model_info = self.__do_forecast_helper(self.fc_rows)
|
res, mse, model_info = self.__do_forecast_helper(self.rows)
|
||||||
insert_ts_list(res, self.start_ts, self.time_step, self.fc_rows)
|
insert_ts_list(res, self.start_ts, self.time_step, self.rows)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"mse": mse,
|
"mse": mse,
|
||||||
|
|
|
@ -10,14 +10,14 @@ from taosanalytics.service import AbstractForecastService
|
||||||
|
|
||||||
|
|
||||||
class _GPTService(AbstractForecastService):
|
class _GPTService(AbstractForecastService):
|
||||||
name = 'TDtsfm_1'
|
name = 'tdtsfm_1'
|
||||||
desc = "internal gpt forecast model based on transformer"
|
desc = "internal gpt forecast model based on transformer"
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.table_name = None
|
self.table_name = None
|
||||||
self.service_host = 'http://127.0.0.1:5000/ds_predict'
|
self.service_host = 'http://127.0.0.1:5000/tdtsfm'
|
||||||
self.headers = {'Content-Type': 'application/json'}
|
self.headers = {'Content-Type': 'application/json'}
|
||||||
|
|
||||||
self.std = None
|
self.std = None
|
||||||
|
@ -29,11 +29,11 @@ class _GPTService(AbstractForecastService):
|
||||||
if self.list is None or len(self.list) < self.period:
|
if self.list is None or len(self.list) < self.period:
|
||||||
raise ValueError("number of input data is less than the periods")
|
raise ValueError("number of input data is less than the periods")
|
||||||
|
|
||||||
if self.fc_rows <= 0:
|
if self.rows <= 0:
|
||||||
raise ValueError("fc rows is not specified yet")
|
raise ValueError("fc rows is not specified yet")
|
||||||
|
|
||||||
# let's request the gpt service
|
# let's request the gpt service
|
||||||
data = {"input": self.list, 'next_len': self.fc_rows}
|
data = {"input": self.list, 'next_len': self.rows}
|
||||||
try:
|
try:
|
||||||
response = requests.post(self.service_host, data=json.dumps(data), headers=self.headers)
|
response = requests.post(self.service_host, data=json.dumps(data), headers=self.headers)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -53,7 +53,7 @@ class _GPTService(AbstractForecastService):
|
||||||
"res": [pred_y]
|
"res": [pred_y]
|
||||||
}
|
}
|
||||||
|
|
||||||
insert_ts_list(res["res"], self.start_ts, self.time_step, self.fc_rows)
|
insert_ts_list(res["res"], self.start_ts, self.time_step, self.rows)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -66,11 +66,11 @@ class _HoltWintersService(AbstractForecastService):
|
||||||
if self.list is None or len(self.list) < self.period:
|
if self.list is None or len(self.list) < self.period:
|
||||||
raise ValueError("number of input data is less than the periods")
|
raise ValueError("number of input data is less than the periods")
|
||||||
|
|
||||||
if self.fc_rows <= 0:
|
if self.rows <= 0:
|
||||||
raise ValueError("fc rows is not specified yet")
|
raise ValueError("fc rows is not specified yet")
|
||||||
|
|
||||||
res, mse = self.__do_forecast_helper(self.list, self.fc_rows)
|
res, mse = self.__do_forecast_helper(self.list, self.rows)
|
||||||
insert_ts_list(res, self.start_ts, self.time_step, self.fc_rows)
|
insert_ts_list(res, self.start_ts, self.time_step, self.rows)
|
||||||
|
|
||||||
# add the conf range if required
|
# add the conf range if required
|
||||||
return {
|
return {
|
||||||
|
|
|
@ -46,7 +46,7 @@ class _LSTMService(AbstractForecastService):
|
||||||
|
|
||||||
res = self.model.predict(self.list)
|
res = self.model.predict(self.list)
|
||||||
|
|
||||||
insert_ts_list(res, self.start_ts, self.time_step, self.fc_rows)
|
insert_ts_list(res, self.start_ts, self.time_step, self.rows)
|
||||||
|
|
||||||
if self.return_conf:
|
if self.return_conf:
|
||||||
res1 = [res.tolist(), res.tolist(), res.tolist()], None
|
res1 = [res.tolist(), res.tolist(), res.tolist()], None
|
||||||
|
|
|
@ -41,7 +41,7 @@ def do_forecast(input_list, ts_list, algo_name, params):
|
||||||
def do_add_fc_params(params, json_obj):
|
def do_add_fc_params(params, json_obj):
|
||||||
""" add params into parameters """
|
""" add params into parameters """
|
||||||
if "forecast_rows" in json_obj:
|
if "forecast_rows" in json_obj:
|
||||||
params["fc_rows"] = int(json_obj["forecast_rows"])
|
params["rows"] = int(json_obj["forecast_rows"])
|
||||||
|
|
||||||
if "start" in json_obj:
|
if "start" in json_obj:
|
||||||
params["start_ts"] = int(json_obj["start"])
|
params["start_ts"] = int(json_obj["start"])
|
||||||
|
|
|
@ -147,7 +147,7 @@ def handle_forecast_req():
|
||||||
|
|
||||||
try:
|
try:
|
||||||
res1 = do_forecast(payload[data_index], payload[ts_index], algo, params)
|
res1 = do_forecast(payload[data_index], payload[ts_index], algo, params)
|
||||||
res = {"option": options, "rows": params["fc_rows"]}
|
res = {"option": options, "rows": params["rows"]}
|
||||||
res.update(res1)
|
res.update(res1)
|
||||||
|
|
||||||
app_logger.log_inst.debug("forecast result: %s", res)
|
app_logger.log_inst.debug("forecast result: %s", res)
|
||||||
|
|
|
@ -77,14 +77,14 @@ class AbstractForecastService(AbstractAnalyticsService, ABC):
|
||||||
self.period = 0
|
self.period = 0
|
||||||
self.start_ts = 0
|
self.start_ts = 0
|
||||||
self.time_step = 0
|
self.time_step = 0
|
||||||
self.fc_rows = 0
|
self.rows = 0
|
||||||
|
|
||||||
self.return_conf = 1
|
self.return_conf = 1
|
||||||
self.conf = 0.05
|
self.conf = 0.05
|
||||||
|
|
||||||
def set_params(self, params: dict) -> None:
|
def set_params(self, params: dict) -> None:
|
||||||
if not {'start_ts', 'time_step', 'fc_rows'}.issubset(params.keys()):
|
if not {'start_ts', 'time_step', 'rows'}.issubset(params.keys()):
|
||||||
raise ValueError('params are missing, start_ts, time_step, fc_rows are all required')
|
raise ValueError('params are missing, start_ts, time_step, rows are all required')
|
||||||
|
|
||||||
self.start_ts = int(params['start_ts'])
|
self.start_ts = int(params['start_ts'])
|
||||||
|
|
||||||
|
@ -93,9 +93,9 @@ class AbstractForecastService(AbstractAnalyticsService, ABC):
|
||||||
if self.time_step <= 0:
|
if self.time_step <= 0:
|
||||||
raise ValueError('time_step should be greater than 0')
|
raise ValueError('time_step should be greater than 0')
|
||||||
|
|
||||||
self.fc_rows = int(params['fc_rows'])
|
self.rows = int(params['rows'])
|
||||||
|
|
||||||
if self.fc_rows <= 0:
|
if self.rows <= 0:
|
||||||
raise ValueError('fc rows is not specified yet')
|
raise ValueError('fc rows is not specified yet')
|
||||||
|
|
||||||
self.period = int(params['period']) if 'period' in params else 0
|
self.period = int(params['period']) if 'period' in params else 0
|
||||||
|
@ -113,5 +113,5 @@ class AbstractForecastService(AbstractAnalyticsService, ABC):
|
||||||
def get_params(self):
|
def get_params(self):
|
||||||
return {
|
return {
|
||||||
"period": self.period, "start": self.start_ts, "every": self.time_step,
|
"period": self.period, "start": self.start_ts, "every": self.time_step,
|
||||||
"forecast_rows": self.fc_rows, "return_conf": self.return_conf, "conf": self.conf
|
"forecast_rows": self.rows, "return_conf": self.return_conf, "conf": self.conf
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,7 +41,7 @@ class ForecastTest(unittest.TestCase):
|
||||||
s.set_input_list(data, ts)
|
s.set_input_list(data, ts)
|
||||||
self.assertRaises(ValueError, s.execute)
|
self.assertRaises(ValueError, s.execute)
|
||||||
|
|
||||||
s.set_params({"fc_rows": 10, "start_ts": 171000000, "time_step": 86400 * 30})
|
s.set_params({"rows": 10, "start_ts": 171000000, "time_step": 86400 * 30})
|
||||||
|
|
||||||
r = s.execute()
|
r = s.execute()
|
||||||
draw_fc_results(data, len(r["res"]) > 2, r["res"], len(r["res"][0]), "holtwinters")
|
draw_fc_results(data, len(r["res"]) > 2, r["res"], len(r["res"][0]), "holtwinters")
|
||||||
|
@ -54,7 +54,7 @@ class ForecastTest(unittest.TestCase):
|
||||||
s.set_input_list(data, ts)
|
s.set_input_list(data, ts)
|
||||||
s.set_params(
|
s.set_params(
|
||||||
{
|
{
|
||||||
"fc_rows": 10, "trend": 'mul', "seasonal": 'mul', "start_ts": 171000000,
|
"rows": 10, "trend": 'mul', "seasonal": 'mul', "start_ts": 171000000,
|
||||||
"time_step": 86400 * 30, "period": 12
|
"time_step": 86400 * 30, "period": 12
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
@ -71,28 +71,28 @@ class ForecastTest(unittest.TestCase):
|
||||||
|
|
||||||
self.assertRaises(ValueError, s.set_params, {"trend": "mul"})
|
self.assertRaises(ValueError, s.set_params, {"trend": "mul"})
|
||||||
|
|
||||||
self.assertRaises(ValueError, s.set_params, {"trend": "mul", "fc_rows": 10})
|
self.assertRaises(ValueError, s.set_params, {"trend": "mul", "rows": 10})
|
||||||
|
|
||||||
self.assertRaises(ValueError, s.set_params, {"trend": "multi"})
|
self.assertRaises(ValueError, s.set_params, {"trend": "multi"})
|
||||||
|
|
||||||
self.assertRaises(ValueError, s.set_params, {"seasonal": "additive"})
|
self.assertRaises(ValueError, s.set_params, {"seasonal": "additive"})
|
||||||
|
|
||||||
self.assertRaises(ValueError, s.set_params, {
|
self.assertRaises(ValueError, s.set_params, {
|
||||||
"fc_rows": 10, "trend": 'multi', "seasonal": 'addi', "start_ts": 171000000,
|
"rows": 10, "trend": 'multi', "seasonal": 'addi', "start_ts": 171000000,
|
||||||
"time_step": 86400 * 30, "period": 12}
|
"time_step": 86400 * 30, "period": 12}
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertRaises(ValueError, s.set_params,
|
self.assertRaises(ValueError, s.set_params,
|
||||||
{"fc_rows": 10, "trend": 'mul', "seasonal": 'add', "time_step": 86400 * 30, "period": 12}
|
{"rows": 10, "trend": 'mul', "seasonal": 'add', "time_step": 86400 * 30, "period": 12}
|
||||||
)
|
)
|
||||||
|
|
||||||
s.set_params({"fc_rows": 10, "start_ts": 171000000, "time_step": 86400 * 30})
|
s.set_params({"rows": 10, "start_ts": 171000000, "time_step": 86400 * 30})
|
||||||
|
|
||||||
self.assertRaises(ValueError, s.set_params, {"fc_rows": 'abc', "start_ts": 171000000, "time_step": 86400 * 30})
|
self.assertRaises(ValueError, s.set_params, {"rows": 'abc', "start_ts": 171000000, "time_step": 86400 * 30})
|
||||||
|
|
||||||
self.assertRaises(ValueError, s.set_params, {"fc_rows": 10, "start_ts": "aaa", "time_step": "30"})
|
self.assertRaises(ValueError, s.set_params, {"rows": 10, "start_ts": "aaa", "time_step": "30"})
|
||||||
|
|
||||||
self.assertRaises(ValueError, s.set_params, {"fc_rows": 10, "start_ts": 171000000, "time_step": 0})
|
self.assertRaises(ValueError, s.set_params, {"rows": 10, "start_ts": 171000000, "time_step": 0})
|
||||||
|
|
||||||
def test_arima(self):
|
def test_arima(self):
|
||||||
"""arima algorithm check"""
|
"""arima algorithm check"""
|
||||||
|
@ -103,7 +103,7 @@ class ForecastTest(unittest.TestCase):
|
||||||
self.assertRaises(ValueError, s.execute)
|
self.assertRaises(ValueError, s.execute)
|
||||||
|
|
||||||
s.set_params(
|
s.set_params(
|
||||||
{"fc_rows": 10, "start_ts": 171000000, "time_step": 86400 * 30, "period": 12,
|
{"rows": 10, "start_ts": 171000000, "time_step": 86400 * 30, "period": 12,
|
||||||
"start_p": 0, "max_p": 10, "start_q": 0, "max_q": 10}
|
"start_p": 0, "max_p": 10, "start_q": 0, "max_q": 10}
|
||||||
)
|
)
|
||||||
r = s.execute()
|
r = s.execute()
|
||||||
|
@ -120,7 +120,7 @@ class ForecastTest(unittest.TestCase):
|
||||||
# s = loader.get_service("td_gpt_fc")
|
# s = loader.get_service("td_gpt_fc")
|
||||||
# s.set_input_list(data, ts)
|
# s.set_input_list(data, ts)
|
||||||
#
|
#
|
||||||
# s.set_params({"host":'192.168.2.90:5000/ds_predict', 'fc_rows': 10, 'start_ts': 171000000, 'time_step': 86400*30})
|
# s.set_params({"host":'192.168.2.90:5000/ds_predict', 'rows': 10, 'start_ts': 171000000, 'time_step': 86400*30})
|
||||||
# r = s.execute()
|
# r = s.execute()
|
||||||
#
|
#
|
||||||
# rows = len(r["res"][0])
|
# rows = len(r["res"][0])
|
||||||
|
|
Loading…
Reference in New Issue