fix(gpt): update the test cases.

This commit is contained in:
Haojun Liao 2025-03-19 17:31:56 +08:00
parent 24de2e76b5
commit fa2803121a
2 changed files with 23 additions and 19 deletions

View File

@ -39,13 +39,13 @@ class _GPTService(AbstractForecastService):
response = requests.post(self.service_host, data=json.dumps(data), headers=self.headers)
except Exception as e:
app_logger.log_inst.error(f"failed to connect the service: {self.service_host} ", str(e))
raise ValueError("error")
raise e
# print(response)
if response.status_code != 200:
app_logger.log_inst.error(f"failed to connect the service: {self.service_host} ")
raise ValueError("invalid host url")
pred_y = response.json()['output']
# print(f"pred_y len:{len(pred_y)}")
# print(f"pred_y:{pred_y}")
res = {
"res": [pred_y]
@ -54,18 +54,6 @@ class _GPTService(AbstractForecastService):
insert_ts_list(res["res"], self.start_ts, self.time_step, self.fc_rows)
return res
# insert_ts_list(res, self.start_ts, self.time_step, self.fc_rows)
#
# if self.return_conf:
# res1 = [res.tolist(), res.tolist(), res.tolist()], None
# else:
# res1 = [res.tolist()], None
#
# # add the conf range if required
# return {
# "mse": None,
# "res": res1
# }
def set_params(self, params):
super().set_params(params)
@ -73,7 +61,7 @@ class _GPTService(AbstractForecastService):
if "host" not in params:
raise ValueError("gpt service host needs to be specified")
self.service_host = params['host'].trim()
self.service_host = params['host']
if self.service_host.startswith("https://"):
self.service_host = self.service_host.replace("https://", "http://")

View File

@ -8,7 +8,7 @@ import pandas as pd
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../")
from taosanalytics.algo.forecast import draw_fc_results
from taosanalytics.conf import setup_log_info
from taosanalytics.conf import setup_log_info, app_logger
from taosanalytics.servicemgmt import loader
@ -30,7 +30,8 @@ class ForecastTest(unittest.TestCase):
ts_list = data[['Passengers']].index.tolist()
dst_list = [int(item.timestamp()) for item in ts_list]
return data[['Passengers']].values.tolist(), dst_list
return data['Passengers'].values.tolist(), dst_list
def test_holt_winters_forecast(self):
""" test holt winters forecast with invalid and then valid parameters"""
@ -111,5 +112,20 @@ class ForecastTest(unittest.TestCase):
draw_fc_results(data, len(r["res"]) > 1, r["res"], rows, "arima")
def test_gpt_fc(self):
"""for local test only, disabled it in github action"""
data, ts = self.get_input_list()
pass
s = loader.get_service("td_gpt_fc")
s.set_input_list(data, ts)
s.set_params({"host":'192.168.2.90:5000/ds_predict', 'fc_rows': 10, 'start_ts': 171000000, 'time_step': 86400*30})
r = s.execute()
rows = len(r["res"][0])
draw_fc_results(data, False, r["res"], rows, "gpt")
if __name__ == '__main__':
unittest.main()