diff --git a/tools/tdgpt/cfg/taosanode.ini b/tools/tdgpt/cfg/taosanode.ini index 2107710a30..5fc4c08bd3 100755 --- a/tools/tdgpt/cfg/taosanode.ini +++ b/tools/tdgpt/cfg/taosanode.ini @@ -79,3 +79,8 @@ log-level = DEBUG # draw the query results draw-result = 0 + +[tsfm-service] +# moe default service host +tdtsfm_1 = http://127.0.0.1:5000/tdtsfm +timemoe-fc = http://127.0.0.1:5001/timemoe diff --git a/tools/tdgpt/taosanalytics/algo/fc/gpt.py b/tools/tdgpt/taosanalytics/algo/fc/gpt.py index f364897dba..5019f30e29 100644 --- a/tools/tdgpt/taosanalytics/algo/fc/gpt.py +++ b/tools/tdgpt/taosanalytics/algo/fc/gpt.py @@ -11,15 +11,20 @@ from taosanalytics.service import AbstractForecastService class _GPTService(AbstractForecastService): name = 'tdtsfm_1' - desc = "internal gpt forecast model based on transformer" + desc = "Time-Series Foundation Model based on transformer by TAOS DATA" def __init__(self): super().__init__() self.table_name = None - self.service_host = 'http://127.0.0.1:5000/tdtsfm' self.headers = {'Content-Type': 'application/json'} + service_host = conf.get_tsfm_service("tdtsfm_1") + if service_host is not None: + self.service_host = service_host + else: + self.service_host = 'http://127.0.0.1:5000/tdtsfm' + def execute(self): if self.list is None or len(self.list) < self.period: diff --git a/tools/tdgpt/taosanalytics/algo/fc/timemoe.py b/tools/tdgpt/taosanalytics/algo/fc/timemoe.py index 4e1ea29a9a..f27b097acb 100644 --- a/tools/tdgpt/taosanalytics/algo/fc/timemoe.py +++ b/tools/tdgpt/taosanalytics/algo/fc/timemoe.py @@ -10,7 +10,7 @@ from taosanalytics.service import AbstractForecastService class _TimeMOEService(AbstractForecastService): - name = 'timemoe' + name = 'timemoe-fc' desc = ("Time-MoE: Billion-Scale Time Series Foundation Models with Mixture of Experts; " "Ref to https://github.com/Time-MoE/Time-MoE") @@ -18,7 +18,13 @@ class _TimeMOEService(AbstractForecastService): super().__init__() self.table_name = None - self.service_host = 'http://127.0.0.1:5001/timemoe' + + service_host = conf.get_tsfm_service("timemoe-fc") + if service_host is not None: + self.service_host = service_host + else: + self.service_host = 'http://127.0.0.1:5001/timemoe' + self.headers = {'Content-Type': 'application/json'} diff --git a/tools/tdgpt/taosanalytics/conf.py b/tools/tdgpt/taosanalytics/conf.py index c255b8e258..f0af5539ee 100644 --- a/tools/tdgpt/taosanalytics/conf.py +++ b/tools/tdgpt/taosanalytics/conf.py @@ -33,6 +33,12 @@ class Configure: """ return model directory """ return self._model_directory + def get_tsfm_service(self, service_name): + if self.conf.has_option("tsfm-service", service_name): + return self.conf.get("tsfm-service", service_name) + else: + return None + def get_draw_result_option(self): """ get the option for draw results or not""" return self._draw_result