fix(gpt): support timemoe (#30410)
This commit is contained in:
parent
30402f66a6
commit
daa129dbb1
|
@ -20,10 +20,6 @@ class _GPTService(AbstractForecastService):
|
||||||
self.service_host = 'http://127.0.0.1:5000/tdtsfm'
|
self.service_host = 'http://127.0.0.1:5000/tdtsfm'
|
||||||
self.headers = {'Content-Type': 'application/json'}
|
self.headers = {'Content-Type': 'application/json'}
|
||||||
|
|
||||||
self.std = None
|
|
||||||
self.threshold = None
|
|
||||||
self.time_interval = None
|
|
||||||
|
|
||||||
|
|
||||||
def execute(self):
|
def execute(self):
|
||||||
if self.list is None or len(self.list) < self.period:
|
if self.list is None or len(self.list) < self.period:
|
||||||
|
|
|
@ -1,81 +0,0 @@
|
||||||
# encoding:utf-8
|
|
||||||
# pylint: disable=c0103
|
|
||||||
""" auto encoder algorithms to detect anomaly for time series data"""
|
|
||||||
import os.path
|
|
||||||
|
|
||||||
import keras
|
|
||||||
|
|
||||||
from taosanalytics.algo.forecast import insert_ts_list
|
|
||||||
from taosanalytics.conf import app_logger, conf
|
|
||||||
from taosanalytics.service import AbstractForecastService
|
|
||||||
|
|
||||||
|
|
||||||
class _LSTMService(AbstractForecastService):
|
|
||||||
name = 'sample_forecast_model'
|
|
||||||
desc = "sample forecast model based on LSTM"
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.table_name = None
|
|
||||||
self.mean = None
|
|
||||||
self.std = None
|
|
||||||
self.threshold = None
|
|
||||||
self.time_interval = None
|
|
||||||
self.model = None
|
|
||||||
self.dir = 'sample-fc-lstm'
|
|
||||||
|
|
||||||
self.root_path = conf.get_model_directory()
|
|
||||||
|
|
||||||
self.root_path = self.root_path + f'/{self.dir}/'
|
|
||||||
|
|
||||||
if not os.path.exists(self.root_path):
|
|
||||||
app_logger.log_inst.error(
|
|
||||||
"%s ad algorithm failed to locate default module directory:"
|
|
||||||
"%s, not active", self.__class__.__name__, self.root_path)
|
|
||||||
else:
|
|
||||||
app_logger.log_inst.info("%s ad algorithm root path is: %s", self.__class__.__name__,
|
|
||||||
self.root_path)
|
|
||||||
|
|
||||||
def execute(self):
|
|
||||||
if self.input_is_empty():
|
|
||||||
return []
|
|
||||||
|
|
||||||
if self.model is None:
|
|
||||||
raise FileNotFoundError("not load autoencoder model yet, or load model failed")
|
|
||||||
|
|
||||||
res = self.model.predict(self.list)
|
|
||||||
|
|
||||||
insert_ts_list(res, self.start_ts, self.time_step, self.rows)
|
|
||||||
|
|
||||||
if self.return_conf:
|
|
||||||
res1 = [res.tolist(), res.tolist(), res.tolist()], None
|
|
||||||
else:
|
|
||||||
res1 = [res.tolist()], None
|
|
||||||
|
|
||||||
# add the conf range if required
|
|
||||||
return {
|
|
||||||
"mse": None,
|
|
||||||
"res": res1
|
|
||||||
}
|
|
||||||
|
|
||||||
def set_params(self, params):
|
|
||||||
|
|
||||||
if "model" not in params:
|
|
||||||
raise ValueError("model needs to be specified")
|
|
||||||
|
|
||||||
name = params['model']
|
|
||||||
|
|
||||||
module_file_path = f'{self.root_path}/{name}.keras'
|
|
||||||
# module_info_path = f'{self.root_path}/{name}.info'
|
|
||||||
|
|
||||||
app_logger.log_inst.info("try to load module:%s", module_file_path)
|
|
||||||
|
|
||||||
if os.path.exists(module_file_path):
|
|
||||||
self.model = keras.models.load_model(module_file_path)
|
|
||||||
else:
|
|
||||||
app_logger.log_inst.error("failed to load LSTM model file: %s", module_file_path)
|
|
||||||
raise FileNotFoundError(f"{module_file_path} not found")
|
|
||||||
|
|
||||||
def get_params(self):
|
|
||||||
return {"dir": self.dir + '/*'}
|
|
|
@ -0,0 +1,70 @@
|
||||||
|
# encoding:utf-8
|
||||||
|
# pylint: disable=c0103
|
||||||
|
""" auto encoder algorithms to detect anomaly for time series data"""
|
||||||
|
import json
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from taosanalytics.algo.forecast import insert_ts_list
|
||||||
|
from taosanalytics.conf import app_logger, conf
|
||||||
|
from taosanalytics.service import AbstractForecastService
|
||||||
|
|
||||||
|
|
||||||
|
class _TimeMOEService(AbstractForecastService):
|
||||||
|
name = 'timemoe'
|
||||||
|
desc = ("Time-MoE: Billion-Scale Time Series Foundation Models with Mixture of Experts; "
|
||||||
|
"Ref to https://github.com/Time-MoE/Time-MoE")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.table_name = None
|
||||||
|
self.service_host = 'http://127.0.0.1:5001/timemoe'
|
||||||
|
self.headers = {'Content-Type': 'application/json'}
|
||||||
|
|
||||||
|
|
||||||
|
def execute(self):
|
||||||
|
if self.list is None or len(self.list) < self.period:
|
||||||
|
raise ValueError("number of input data is less than the periods")
|
||||||
|
|
||||||
|
if self.rows <= 0:
|
||||||
|
raise ValueError("fc rows is not specified yet")
|
||||||
|
|
||||||
|
# let's request the gpt service
|
||||||
|
data = {"input": self.list, 'next_len': self.rows}
|
||||||
|
try:
|
||||||
|
response = requests.post(self.service_host, data=json.dumps(data), headers=self.headers)
|
||||||
|
except Exception as e:
|
||||||
|
app_logger.log_inst.error(f"failed to connect the service: {self.service_host} ", str(e))
|
||||||
|
raise e
|
||||||
|
|
||||||
|
if response.status_code == 404:
|
||||||
|
app_logger.log_inst.error(f"failed to connect the service: {self.service_host} ")
|
||||||
|
raise ValueError("invalid host url")
|
||||||
|
elif response.status_code != 200:
|
||||||
|
app_logger.log_inst.error(f"failed to request the service: {self.service_host}, reason: {response.text}")
|
||||||
|
raise ValueError(f"failed to request the service, {response.text}")
|
||||||
|
|
||||||
|
pred_y = response.json()['output']
|
||||||
|
|
||||||
|
res = {
|
||||||
|
"res": [pred_y]
|
||||||
|
}
|
||||||
|
|
||||||
|
insert_ts_list(res["res"], self.start_ts, self.time_step, self.rows)
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def set_params(self, params):
|
||||||
|
super().set_params(params)
|
||||||
|
|
||||||
|
if "host" in params:
|
||||||
|
self.service_host = params['host']
|
||||||
|
|
||||||
|
if self.service_host.startswith("https://"):
|
||||||
|
self.service_host = self.service_host.replace("https://", "http://")
|
||||||
|
elif "http://" not in self.service_host:
|
||||||
|
self.service_host = "http://" + self.service_host
|
||||||
|
|
||||||
|
app_logger.log_inst.info("%s specify gpt host service: %s", self.__class__.__name__,
|
||||||
|
self.service_host)
|
||||||
|
|
|
@ -57,23 +57,30 @@ def list_all_models():
|
||||||
def handle_ad_request():
|
def handle_ad_request():
|
||||||
"""handle the anomaly detection requests"""
|
"""handle the anomaly detection requests"""
|
||||||
app_logger.log_inst.info('recv ad request from %s', request.remote_addr)
|
app_logger.log_inst.info('recv ad request from %s', request.remote_addr)
|
||||||
app_logger.log_inst.debug('req payload: %s', request.json)
|
|
||||||
|
|
||||||
algo = request.json["algo"].lower() if "algo" in request.json else "ksigma"
|
try:
|
||||||
|
req_json = request.json
|
||||||
|
except Exception as e:
|
||||||
|
app_logger.log_inst.error('invalid json format, %s, %s', e, request.data)
|
||||||
|
raise ValueError(e)
|
||||||
|
|
||||||
|
app_logger.log_inst.debug('req payload: %s', req_json)
|
||||||
|
|
||||||
|
algo = req_json["algo"].lower() if "algo" in req_json else "ksigma"
|
||||||
|
|
||||||
# 1. validate the input data in json format
|
# 1. validate the input data in json format
|
||||||
try:
|
try:
|
||||||
validate_pay_load(request.json)
|
validate_pay_load(req_json)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return {"msg": str(e), "rows": -1}
|
return {"msg": str(e), "rows": -1}
|
||||||
|
|
||||||
payload = request.json["data"]
|
payload = req_json["data"]
|
||||||
|
|
||||||
# 2. white noise data check
|
# 2. white noise data check
|
||||||
wn_check = request.json["wncheck"] if "wncheck" in request.json else 1
|
wn_check = req_json["wncheck"] if "wncheck" in req_json else 1
|
||||||
|
|
||||||
data_index = get_data_index(request.json["schema"])
|
data_index = get_data_index(req_json["schema"])
|
||||||
ts_index = get_ts_index(request.json["schema"])
|
ts_index = get_ts_index(req_json["schema"])
|
||||||
|
|
||||||
if wn_check:
|
if wn_check:
|
||||||
try:
|
try:
|
||||||
|
@ -86,7 +93,7 @@ def handle_ad_request():
|
||||||
|
|
||||||
# 3. parse the options for different ad services
|
# 3. parse the options for different ad services
|
||||||
# the default options is like following: "algo=ksigma,k=2,invalid_option=44"
|
# the default options is like following: "algo=ksigma,k=2,invalid_option=44"
|
||||||
options = request.json["option"] if "option" in request.json else None
|
options = req_json["option"] if "option" in req_json else None
|
||||||
params = parse_options(options)
|
params = parse_options(options)
|
||||||
|
|
||||||
# 4. do anomaly detection
|
# 4. do anomaly detection
|
||||||
|
@ -108,24 +115,31 @@ def handle_ad_request():
|
||||||
def handle_forecast_req():
|
def handle_forecast_req():
|
||||||
"""handle the fc request """
|
"""handle the fc request """
|
||||||
app_logger.log_inst.info('recv fc from %s', request.remote_addr)
|
app_logger.log_inst.info('recv fc from %s', request.remote_addr)
|
||||||
app_logger.log_inst.debug('req payload: %s', request.json)
|
|
||||||
|
try:
|
||||||
|
req_json = request.json
|
||||||
|
except Exception as e:
|
||||||
|
app_logger.log_inst.error('forecast recv invalid json format, %s, %s', e, request.data)
|
||||||
|
raise ValueError(e)
|
||||||
|
|
||||||
|
app_logger.log_inst.debug('req payload: %s', req_json)
|
||||||
|
|
||||||
# holt-winters by default
|
# holt-winters by default
|
||||||
algo = request.json['algo'].lower() if 'algo' in request.json else 'holtwinters'
|
algo = req_json['algo'].lower() if 'algo' in req_json else 'holtwinters'
|
||||||
|
|
||||||
# 1. validate the input data in json format
|
# 1. validate the input data in json format
|
||||||
try:
|
try:
|
||||||
validate_pay_load(request.json)
|
validate_pay_load(req_json)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
app_logger.log_inst.error('validate req json failed, %s', e)
|
app_logger.log_inst.error('validate req json failed, %s', e)
|
||||||
return {"msg": str(e), "rows": -1}
|
raise ValueError(e)
|
||||||
|
|
||||||
payload = request.json["data"]
|
payload = req_json["data"]
|
||||||
|
|
||||||
# 2. white noise data check
|
# 2. white noise data check
|
||||||
wn_check = request.json["wncheck"] if "wncheck" in request.json else 1
|
wn_check = req_json["wncheck"] if "wncheck" in req_json else 1
|
||||||
data_index = get_data_index(request.json["schema"])
|
data_index = get_data_index(req_json["schema"])
|
||||||
ts_index = get_ts_index(request.json["schema"])
|
ts_index = get_ts_index(req_json["schema"])
|
||||||
|
|
||||||
if wn_check:
|
if wn_check:
|
||||||
try:
|
try:
|
||||||
|
@ -136,11 +150,11 @@ def handle_forecast_req():
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"msg": str(e), "rows": -1}
|
return {"msg": str(e), "rows": -1}
|
||||||
|
|
||||||
options = request.json["option"] if "option" in request.json else None
|
options = req_json["option"] if "option" in req_json else None
|
||||||
params = parse_options(options)
|
params = parse_options(options)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
do_add_fc_params(params, request.json)
|
do_add_fc_params(params, req_json)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
app_logger.log_inst.error("invalid fc params: %s", e)
|
app_logger.log_inst.error("invalid fc params: %s", e)
|
||||||
return {"msg": f"{e}", "rows": -1}
|
return {"msg": f"{e}", "rows": -1}
|
||||||
|
|
Loading…
Reference in New Issue