fix(gpt): add host into script (#30440)

This commit is contained in:
Haojun Liao 2025-03-25 14:15:52 +08:00 committed by GitHub
parent 13475700c5
commit 0162bfa222
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 26 additions and 4 deletions

View File

@ -79,3 +79,8 @@ log-level = DEBUG
# draw the query results
draw-result = 0
[tsfm-service]
# moe default service host
tdtsfm_1 = http://127.0.0.1:5000/tdtsfm
timemoe-fc = http://127.0.0.1:5001/timemoe

View File

@ -11,15 +11,20 @@ from taosanalytics.service import AbstractForecastService
class _GPTService(AbstractForecastService):
name = 'tdtsfm_1'
desc = "internal gpt forecast model based on transformer"
desc = "Time-Series Foundation Model based on transformer by TAOS DATA"
def __init__(self):
super().__init__()
self.table_name = None
self.service_host = 'http://127.0.0.1:5000/tdtsfm'
self.headers = {'Content-Type': 'application/json'}
service_host = conf.get_tsfm_service("tdtsfm_1")
if service_host is not None:
self.service_host = service_host
else:
self.service_host = 'http://127.0.0.1:5000/tdtsfm'
def execute(self):
if self.list is None or len(self.list) < self.period:

View File

@ -10,7 +10,7 @@ from taosanalytics.service import AbstractForecastService
class _TimeMOEService(AbstractForecastService):
name = 'timemoe'
name = 'timemoe-fc'
desc = ("Time-MoE: Billion-Scale Time Series Foundation Models with Mixture of Experts; "
"Ref to https://github.com/Time-MoE/Time-MoE")
@ -18,7 +18,13 @@ class _TimeMOEService(AbstractForecastService):
super().__init__()
self.table_name = None
service_host = conf.get_tsfm_service("timemoe-fc")
if service_host is not None:
self.service_host = service_host
else:
self.service_host = 'http://127.0.0.1:5001/timemoe'
self.headers = {'Content-Type': 'application/json'}

View File

@ -33,6 +33,12 @@ class Configure:
""" return model directory """
return self._model_directory
def get_tsfm_service(self, service_name):
if self.conf.has_option("tsfm-service", service_name):
return self.conf.get("tsfm-service", service_name)
else:
return None
def get_draw_result_option(self):
""" get the option for draw results or not"""
return self._draw_result