fix(gpt): update the test cases.
This commit is contained in:
parent
24de2e76b5
commit
fa2803121a
|
@ -39,13 +39,13 @@ class _GPTService(AbstractForecastService):
|
|||
response = requests.post(self.service_host, data=json.dumps(data), headers=self.headers)
|
||||
except Exception as e:
|
||||
app_logger.log_inst.error(f"failed to connect the service: {self.service_host} ", str(e))
|
||||
raise ValueError("error")
|
||||
raise e
|
||||
|
||||
# print(response)
|
||||
if response.status_code != 200:
|
||||
app_logger.log_inst.error(f"failed to connect the service: {self.service_host} ")
|
||||
raise ValueError("invalid host url")
|
||||
|
||||
pred_y = response.json()['output']
|
||||
# print(f"pred_y len:{len(pred_y)}")
|
||||
# print(f"pred_y:{pred_y}")
|
||||
|
||||
res = {
|
||||
"res": [pred_y]
|
||||
|
@ -54,18 +54,6 @@ class _GPTService(AbstractForecastService):
|
|||
insert_ts_list(res["res"], self.start_ts, self.time_step, self.fc_rows)
|
||||
return res
|
||||
|
||||
# insert_ts_list(res, self.start_ts, self.time_step, self.fc_rows)
|
||||
#
|
||||
# if self.return_conf:
|
||||
# res1 = [res.tolist(), res.tolist(), res.tolist()], None
|
||||
# else:
|
||||
# res1 = [res.tolist()], None
|
||||
#
|
||||
# # add the conf range if required
|
||||
# return {
|
||||
# "mse": None,
|
||||
# "res": res1
|
||||
# }
|
||||
|
||||
def set_params(self, params):
|
||||
super().set_params(params)
|
||||
|
@ -73,7 +61,7 @@ class _GPTService(AbstractForecastService):
|
|||
if "host" not in params:
|
||||
raise ValueError("gpt service host needs to be specified")
|
||||
|
||||
self.service_host = params['host'].trim()
|
||||
self.service_host = params['host']
|
||||
|
||||
if self.service_host.startswith("https://"):
|
||||
self.service_host = self.service_host.replace("https://", "http://")
|
||||
|
|
|
@ -8,7 +8,7 @@ import pandas as pd
|
|||
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../")
|
||||
|
||||
from taosanalytics.algo.forecast import draw_fc_results
|
||||
from taosanalytics.conf import setup_log_info
|
||||
from taosanalytics.conf import setup_log_info, app_logger
|
||||
from taosanalytics.servicemgmt import loader
|
||||
|
||||
|
||||
|
@ -30,7 +30,8 @@ class ForecastTest(unittest.TestCase):
|
|||
ts_list = data[['Passengers']].index.tolist()
|
||||
dst_list = [int(item.timestamp()) for item in ts_list]
|
||||
|
||||
return data[['Passengers']].values.tolist(), dst_list
|
||||
return data['Passengers'].values.tolist(), dst_list
|
||||
|
||||
|
||||
def test_holt_winters_forecast(self):
|
||||
""" test holt winters forecast with invalid and then valid parameters"""
|
||||
|
@ -111,5 +112,20 @@ class ForecastTest(unittest.TestCase):
|
|||
draw_fc_results(data, len(r["res"]) > 1, r["res"], rows, "arima")
|
||||
|
||||
|
||||
def test_gpt_fc(self):
|
||||
"""for local test only, disabled it in github action"""
|
||||
data, ts = self.get_input_list()
|
||||
pass
|
||||
|
||||
s = loader.get_service("td_gpt_fc")
|
||||
s.set_input_list(data, ts)
|
||||
|
||||
s.set_params({"host":'192.168.2.90:5000/ds_predict', 'fc_rows': 10, 'start_ts': 171000000, 'time_step': 86400*30})
|
||||
r = s.execute()
|
||||
|
||||
rows = len(r["res"][0])
|
||||
draw_fc_results(data, False, r["res"], rows, "gpt")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
Loading…
Reference in New Issue