fix(gpt): 1. fix error in anomaly window count; 2. update gpt model name; 3. update the expired test case; 4. support the user-specified valid_code

This commit is contained in:
Haojun Liao 2025-03-21 12:52:41 +08:00
parent c400471917
commit 47004861c7
9 changed files with 29 additions and 23 deletions

View File

@ -593,7 +593,7 @@ static int32_t anomalyAggregateBlocks(SOperatorInfo* pOperator) {
for (int32_t r = 0; r < pBlock->info.rows; ++r) {
TSKEY key = tsList[r];
bool keyInWin = (key >= pSupp->curWin.skey && key < pSupp->curWin.ekey);
bool keyInWin = (key >= pSupp->curWin.skey && key <= pSupp->curWin.ekey);
bool lastRow = (r == pBlock->info.rows - 1);
if (keyInWin) {

View File

@ -23,8 +23,8 @@ endi
print =============== show info
sql show anodes full
if $rows != 8 then
print expect 8 , actual $rows
if $rows != 10 then
print expect 10 , actual $rows
return -1
endi

View File

@ -78,4 +78,4 @@ model-dir = /usr/local/taos/taosanode/model/
log-level = DEBUG
# draw the query results
draw-result = 1
draw-result = 0

View File

@ -5,6 +5,7 @@
from matplotlib import pyplot as plt
from taosanalytics.conf import app_logger, conf
from taosanalytics.servicemgmt import loader
from taosanalytics.util import convert_results_to_windows
def do_ad_check(input_list, ts_list, algo_name, params):
@ -22,17 +23,19 @@ def do_ad_check(input_list, ts_list, algo_name, params):
res = s.execute()
n_error = abs(sum(filter(lambda x: x == -1, res)))
n_error = abs(sum(filter(lambda x: x != s.valid_code, res)))
app_logger.log_inst.debug("There are %d in input, and %d anomaly points found: %s",
len(input_list),
n_error,
res)
draw_ad_results(input_list, res, algo_name)
return res
# draw_ad_results(input_list, res, algo_name, s.valid_code)
ano_window = convert_results_to_windows(res, ts_list, s.valid_code)
return res, ano_window
def draw_ad_results(input_list, res, fig_name):
def draw_ad_results(input_list, res, fig_name, valid_code):
""" draw the detected anomaly points """
# not in debug, do not visualize the anomaly detection result
@ -41,8 +44,7 @@ def draw_ad_results(input_list, res, fig_name):
plt.clf()
for index, val in enumerate(res):
if val != -1:
continue
if val != valid_code:
plt.scatter(index, input_list[index], marker='o', color='r', alpha=0.5, s=100, zorder=3)
plt.plot(input_list, label='sample')

View File

@ -10,7 +10,7 @@ from taosanalytics.service import AbstractForecastService
class _GPTService(AbstractForecastService):
name = 'td_gpt_fc'
name = 'TDtsfm_1'
desc = "internal gpt forecast model based on transformer"
def __init__(self):
@ -23,7 +23,6 @@ class _GPTService(AbstractForecastService):
self.std = None
self.threshold = None
self.time_interval = None
self.dir = 'internal-gpt'
def execute(self):

View File

@ -91,9 +91,7 @@ def handle_ad_request():
# 4. do anomaly detection
try:
res_list = do_ad_check(payload[data_index], payload[ts_index], algo, params)
ano_window = convert_results_to_windows(res_list, payload[ts_index])
res_list, ano_window = do_ad_check(payload[data_index], payload[ts_index], algo, params)
result = {"algo": algo, "option": options, "res": ano_window, "rows": len(ano_window)}
app_logger.log_inst.debug("anomaly-detection result: %s", str(result))

View File

@ -51,6 +51,7 @@ class AbstractAnomalyDetectionService(AbstractAnalyticsService, ABC):
inherent from this class"""
def __init__(self):
self.valid_code = 1
super().__init__()
self.type = "anomaly-detection"
@ -58,6 +59,12 @@ class AbstractAnomalyDetectionService(AbstractAnalyticsService, ABC):
""" check if the input list is empty or None """
return (self.list is None) or (len(self.list) == 0)
def set_params(self, params: dict) -> None:
    """Apply the shared parameters, then honour an optional ``valid_code`` override.

    The parent implementation processes ``params`` first. When ``params``
    contains a ``valid_code`` entry, its value is coerced with ``int()`` and
    stored on the instance; otherwise the instance's current ``valid_code``
    is left untouched. ``int()`` raises ``ValueError`` or ``TypeError`` when
    the supplied value cannot be converted.
    """
    super().set_params(params)

    # No override requested: keep whatever valid_code the instance already has.
    if "valid_code" not in params:
        return

    override = params["valid_code"]
    self.valid_code = int(override)
class AbstractForecastService(AbstractAnalyticsService, ABC):
"""abstract forecast service, all forecast algorithms class should be inherent from

View File

@ -44,7 +44,7 @@ class AnomalyDetectionTest(unittest.TestCase):
s.set_params({"k": 2})
r = s.execute()
draw_ad_results(AnomalyDetectionTest.input_list, r, "ksigma")
draw_ad_results(AnomalyDetectionTest.input_list, r, "ksigma", s.valid_code)
self.assertEqual(r[-1], -1)
self.assertEqual(len(r), len(AnomalyDetectionTest.input_list))
@ -64,7 +64,7 @@ class AnomalyDetectionTest(unittest.TestCase):
self.assertEqual(1, 0, e)
r = s.execute()
draw_ad_results(AnomalyDetectionTest.input_list, r, "iqr")
draw_ad_results(AnomalyDetectionTest.input_list, r, "iqr", s.valid_code)
self.assertEqual(r[-1], -1)
self.assertEqual(len(r), len(AnomalyDetectionTest.input_list))
@ -82,7 +82,7 @@ class AnomalyDetectionTest(unittest.TestCase):
s.set_params({"alpha": 0.95})
r = s.execute()
draw_ad_results(AnomalyDetectionTest.input_list, r, "grubbs")
draw_ad_results(AnomalyDetectionTest.input_list, r, "grubbs", s.valid_code)
self.assertEqual(r[-1], -1)
self.assertEqual(len(r), len(AnomalyDetectionTest.input_list))
@ -100,7 +100,7 @@ class AnomalyDetectionTest(unittest.TestCase):
s.set_input_list(AnomalyDetectionTest.input_list, None)
r = s.execute()
draw_ad_results(AnomalyDetectionTest.input_list, r, "shesd")
draw_ad_results(AnomalyDetectionTest.input_list, r, "shesd", s.valid_code)
self.assertEqual(r[-1], -1)
@ -116,7 +116,7 @@ class AnomalyDetectionTest(unittest.TestCase):
s.set_input_list(AnomalyDetectionTest.input_list, None)
r = s.execute()
draw_ad_results(AnomalyDetectionTest.input_list, r, "lof")
draw_ad_results(AnomalyDetectionTest.input_list, r, "lof", s.valid_code)
self.assertEqual(r[-1], -1)
self.assertEqual(r[-2], -1)

View File

@ -36,7 +36,7 @@ def validate_pay_load(json_obj):
raise ValueError('invalid schema info, data column is missing')
def convert_results_to_windows(result, ts_list):
def convert_results_to_windows(result, ts_list, valid_code):
"""generate the window according to anomaly detection result"""
skey, ekey = -1, -1
wins = []
@ -45,7 +45,7 @@ def convert_results_to_windows(result, ts_list):
return wins
for index, val in enumerate(result):
if val == -1:
if val != valid_code:
ekey = ts_list[index]
if skey == -1:
skey = ts_list[index]