fix(gpt): update the test cases.
This commit is contained in:
parent
24de2e76b5
commit
fa2803121a
|
@ -39,13 +39,13 @@ class _GPTService(AbstractForecastService):
|
||||||
response = requests.post(self.service_host, data=json.dumps(data), headers=self.headers)
|
response = requests.post(self.service_host, data=json.dumps(data), headers=self.headers)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
app_logger.log_inst.error(f"failed to connect the service: {self.service_host} ", str(e))
|
app_logger.log_inst.error(f"failed to connect the service: {self.service_host} ", str(e))
|
||||||
raise ValueError("error")
|
raise e
|
||||||
|
|
||||||
# print(response)
|
if response.status_code != 200:
|
||||||
|
app_logger.log_inst.error(f"failed to connect the service: {self.service_host} ")
|
||||||
|
raise ValueError("invalid host url")
|
||||||
|
|
||||||
pred_y = response.json()['output']
|
pred_y = response.json()['output']
|
||||||
# print(f"pred_y len:{len(pred_y)}")
|
|
||||||
# print(f"pred_y:{pred_y}")
|
|
||||||
|
|
||||||
res = {
|
res = {
|
||||||
"res": [pred_y]
|
"res": [pred_y]
|
||||||
|
@ -54,18 +54,6 @@ class _GPTService(AbstractForecastService):
|
||||||
insert_ts_list(res["res"], self.start_ts, self.time_step, self.fc_rows)
|
insert_ts_list(res["res"], self.start_ts, self.time_step, self.fc_rows)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
# insert_ts_list(res, self.start_ts, self.time_step, self.fc_rows)
|
|
||||||
#
|
|
||||||
# if self.return_conf:
|
|
||||||
# res1 = [res.tolist(), res.tolist(), res.tolist()], None
|
|
||||||
# else:
|
|
||||||
# res1 = [res.tolist()], None
|
|
||||||
#
|
|
||||||
# # add the conf range if required
|
|
||||||
# return {
|
|
||||||
# "mse": None,
|
|
||||||
# "res": res1
|
|
||||||
# }
|
|
||||||
|
|
||||||
def set_params(self, params):
|
def set_params(self, params):
|
||||||
super().set_params(params)
|
super().set_params(params)
|
||||||
|
@ -73,7 +61,7 @@ class _GPTService(AbstractForecastService):
|
||||||
if "host" not in params:
|
if "host" not in params:
|
||||||
raise ValueError("gpt service host needs to be specified")
|
raise ValueError("gpt service host needs to be specified")
|
||||||
|
|
||||||
self.service_host = params['host'].trim()
|
self.service_host = params['host']
|
||||||
|
|
||||||
if self.service_host.startswith("https://"):
|
if self.service_host.startswith("https://"):
|
||||||
self.service_host = self.service_host.replace("https://", "http://")
|
self.service_host = self.service_host.replace("https://", "http://")
|
||||||
|
|
|
@ -8,7 +8,7 @@ import pandas as pd
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../")
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../")
|
||||||
|
|
||||||
from taosanalytics.algo.forecast import draw_fc_results
|
from taosanalytics.algo.forecast import draw_fc_results
|
||||||
from taosanalytics.conf import setup_log_info
|
from taosanalytics.conf import setup_log_info, app_logger
|
||||||
from taosanalytics.servicemgmt import loader
|
from taosanalytics.servicemgmt import loader
|
||||||
|
|
||||||
|
|
||||||
|
@ -30,7 +30,8 @@ class ForecastTest(unittest.TestCase):
|
||||||
ts_list = data[['Passengers']].index.tolist()
|
ts_list = data[['Passengers']].index.tolist()
|
||||||
dst_list = [int(item.timestamp()) for item in ts_list]
|
dst_list = [int(item.timestamp()) for item in ts_list]
|
||||||
|
|
||||||
return data[['Passengers']].values.tolist(), dst_list
|
return data['Passengers'].values.tolist(), dst_list
|
||||||
|
|
||||||
|
|
||||||
def test_holt_winters_forecast(self):
|
def test_holt_winters_forecast(self):
|
||||||
""" test holt winters forecast with invalid and then valid parameters"""
|
""" test holt winters forecast with invalid and then valid parameters"""
|
||||||
|
@ -111,5 +112,20 @@ class ForecastTest(unittest.TestCase):
|
||||||
draw_fc_results(data, len(r["res"]) > 1, r["res"], rows, "arima")
|
draw_fc_results(data, len(r["res"]) > 1, r["res"], rows, "arima")
|
||||||
|
|
||||||
|
|
||||||
|
def test_gpt_fc(self):
|
||||||
|
"""for local test only, disabled it in github action"""
|
||||||
|
data, ts = self.get_input_list()
|
||||||
|
pass
|
||||||
|
|
||||||
|
s = loader.get_service("td_gpt_fc")
|
||||||
|
s.set_input_list(data, ts)
|
||||||
|
|
||||||
|
s.set_params({"host":'192.168.2.90:5000/ds_predict', 'fc_rows': 10, 'start_ts': 171000000, 'time_step': 86400*30})
|
||||||
|
r = s.execute()
|
||||||
|
|
||||||
|
rows = len(r["res"][0])
|
||||||
|
draw_fc_results(data, False, r["res"], rows, "gpt")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
Loading…
Reference in New Issue