fix(gpt): add gpt (#30273)

* fix(stream): support packaging enterprise edition.

* feat(gpt): support lstm and do some internal refactor, add sample autoencoder model.

* feat(gpt): support lstm and do some internal refactor, add sample autoencoder model.

* test(gpt): disable model case.

* test(gpt): disable model case.

* doc: fix title error in doc.

* doc: add mlp doc.

* fix(gpt): add gpt

* fix(gpt): update the test cases.
This commit is contained in:
Haojun Liao 2025-03-19 16:48:01 +08:00 committed by GitHub
parent 9d8496264f
commit 99bd700fb3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 86 additions and 1 deletions

View File

@ -0,0 +1,85 @@
# encoding:utf-8
# pylint: disable=c0103
""" auto encoder algorithms to detect anomaly for time series data"""
import json
import requests
from taosanalytics.algo.forecast import insert_ts_list
from taosanalytics.conf import app_logger, conf
from taosanalytics.service import AbstractForecastService
class _GPTService(AbstractForecastService):
name = 'td_gpt_fc'
desc = "internal gpt forecast model based on transformer"
def __init__(self):
super().__init__()
self.table_name = None
self.service_host = 'http://192.168.2.90:5000/ds_predict'
self.headers = {'Content-Type': 'application/json'}
self.std = None
self.threshold = None
self.time_interval = None
self.dir = 'internal-gpt'
def execute(self):
if self.list is None or len(self.list) < self.period:
raise ValueError("number of input data is less than the periods")
if self.fc_rows <= 0:
raise ValueError("fc rows is not specified yet")
# let's request the gpt service
data = {"input": self.list, 'next_len': self.fc_rows}
try:
response = requests.post(self.service_host, data=json.dumps(data), headers=self.headers)
except Exception as e:
app_logger.log_inst.error(f"failed to connect the service: {self.service_host} ", str(e))
raise ValueError("error")
# print(response)
pred_y = response.json()['output']
# print(f"pred_y len:{len(pred_y)}")
# print(f"pred_y:{pred_y}")
res = {
"res": [pred_y]
}
insert_ts_list(res["res"], self.start_ts, self.time_step, self.fc_rows)
return res
# insert_ts_list(res, self.start_ts, self.time_step, self.fc_rows)
#
# if self.return_conf:
# res1 = [res.tolist(), res.tolist(), res.tolist()], None
# else:
# res1 = [res.tolist()], None
#
# # add the conf range if required
# return {
# "mse": None,
# "res": res1
# }
def set_params(self, params):
super().set_params(params)
if "host" not in params:
raise ValueError("gpt service host needs to be specified")
self.service_host = params['host'].trim()
if self.service_host.startswith("https://"):
self.service_host = self.service_host.replace("https://", "http://")
elif "http://" not in self.service_host:
self.service_host = "http://" + self.service_host
app_logger.log_inst.info("%s specify gpt host service: %s", self.__class__.__name__,
self.service_host)

View File

@ -99,7 +99,7 @@ class ServiceTest(unittest.TestCase):
if item["type"] == "anomaly-detection":
self.assertEqual(len(item["algo"]), 6)
else:
self.assertEqual(len(item["algo"]), 3)
self.assertEqual(len(item["algo"]), 4)
if __name__ == '__main__':