diff --git a/tools/tdgpt/taosanalytics/algo/fc/gpt.py b/tools/tdgpt/taosanalytics/algo/fc/gpt.py index a279630722..898591cce4 100644 --- a/tools/tdgpt/taosanalytics/algo/fc/gpt.py +++ b/tools/tdgpt/taosanalytics/algo/fc/gpt.py @@ -39,13 +39,13 @@ class _GPTService(AbstractForecastService): response = requests.post(self.service_host, data=json.dumps(data), headers=self.headers) except Exception as e: app_logger.log_inst.error(f"failed to connect the service: {self.service_host} ", str(e)) - raise ValueError("error") + raise e - # print(response) + if response.status_code != 200: + app_logger.log_inst.error(f"failed to connect the service: {self.service_host} ") + raise ValueError("invalid host url") pred_y = response.json()['output'] - # print(f"pred_y len:{len(pred_y)}") - # print(f"pred_y:{pred_y}") res = { "res": [pred_y] @@ -54,18 +54,6 @@ class _GPTService(AbstractForecastService): insert_ts_list(res["res"], self.start_ts, self.time_step, self.fc_rows) return res - # insert_ts_list(res, self.start_ts, self.time_step, self.fc_rows) - # - # if self.return_conf: - # res1 = [res.tolist(), res.tolist(), res.tolist()], None - # else: - # res1 = [res.tolist()], None - # - # # add the conf range if required - # return { - # "mse": None, - # "res": res1 - # } def set_params(self, params): super().set_params(params) @@ -73,7 +61,7 @@ class _GPTService(AbstractForecastService): if "host" not in params: raise ValueError("gpt service host needs to be specified") - self.service_host = params['host'].trim() + self.service_host = params['host'] if self.service_host.startswith("https://"): self.service_host = self.service_host.replace("https://", "http://") diff --git a/tools/tdgpt/taosanalytics/test/forecast_test.py b/tools/tdgpt/taosanalytics/test/forecast_test.py index 1e4874b8c8..6951c700b6 100644 --- a/tools/tdgpt/taosanalytics/test/forecast_test.py +++ b/tools/tdgpt/taosanalytics/test/forecast_test.py @@ -8,7 +8,7 @@ import pandas as pd sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/../../") from taosanalytics.algo.forecast import draw_fc_results -from taosanalytics.conf import setup_log_info +from taosanalytics.conf import setup_log_info, app_logger from taosanalytics.servicemgmt import loader @@ -30,7 +30,8 @@ class ForecastTest(unittest.TestCase): ts_list = data[['Passengers']].index.tolist() dst_list = [int(item.timestamp()) for item in ts_list] - return data[['Passengers']].values.tolist(), dst_list + return data['Passengers'].values.tolist(), dst_list + def test_holt_winters_forecast(self): """ test holt winters forecast with invalid and then valid parameters""" @@ -111,5 +112,20 @@ class ForecastTest(unittest.TestCase): draw_fc_results(data, len(r["res"]) > 1, r["res"], rows, "arima") + def test_gpt_fc(self): + """for local test only, disabled it in github action""" + data, ts = self.get_input_list() + pass + + s = loader.get_service("td_gpt_fc") + s.set_input_list(data, ts) + + s.set_params({"host":'192.168.2.90:5000/ds_predict', 'fc_rows': 10, 'start_ts': 171000000, 'time_step': 86400*30}) + r = s.execute() + + rows = len(r["res"][0]) + draw_fc_results(data, False, r["res"], rows, "gpt") + + if __name__ == '__main__': unittest.main()