From 181a4bb0fa5b77aeda8504c711a5ad3aad8166d1 Mon Sep 17 00:00:00 2001
From: jiajingbin
Date: Tue, 18 Mar 2025 17:59:43 +0800
Subject: [PATCH] fix: fix flake8 check

---
 .../tdgpt/dockerfile/tdgpt/taos_ts_server.py | 20 +++++++++++++++----
 1 file changed, 16 insertions(+), 4 deletions(-)

diff --git a/tools/tdgpt/dockerfile/tdgpt/taos_ts_server.py b/tools/tdgpt/dockerfile/tdgpt/taos_ts_server.py
index 1836475f81..68dbdc14e1 100644
--- a/tools/tdgpt/dockerfile/tdgpt/taos_ts_server.py
+++ b/tools/tdgpt/dockerfile/tdgpt/taos_ts_server.py
@@ -97,6 +97,9 @@ class TaosConfig(PretrainedConfig):
         self.Taos_max_position_embeddings = Taos_max_position_embeddings
         super().__init__(**kwargs)
 
+class BaseStreamer:
+    pass
+
 class TaosTSGenerationMixin(GenerationMixin):
     @torch.no_grad()
     def generate(
@@ -593,7 +596,7 @@ class TaosForPrediction(TaosPreTrainedModel, TaosTSGenerationMixin):
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        
+
         if revin:
             mean, std = input_ids.mean(dim=-1, keepdim=True), input_ids.std(dim=-1, keepdim=True)
             input_ids = (input_ids - mean) / std
@@ -636,7 +639,7 @@ class TaosForPrediction(TaosPreTrainedModel, TaosTSGenerationMixin):
         if output_token_len > max_output_length:
             predictions = predictions[:, :max_output_length]
         if revin:
-            predictions = predictions * std + mean 
+            predictions = predictions * std + mean
         if not return_dict:
             output = (predictions,) + outputs[1:]
             return (loss) + output if loss is not None else output
@@ -718,18 +721,27 @@ def init_model():
         Taos_hidden_act="gelu",
         Taos_output_token_lens=[96],
     )
-    
+
     Taos_model = TaosForPrediction(config)
     model_path = "taos.pth"
     Taos_model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=True)
     Taos_model = Taos_model.to(device)
-    
+
     #src_data = load_data('src_data/holt-winters_1.txt')
     #print(f"src_data:{src_data}")
     #seqs = torch.tensor(src_data).unsqueeze(0).float()
     #seqs = seqs.to(device)
     print(Taos_model)
 
+def train():
+    pass
+
+def infer():
+    pass
+
+def data_view():
+    pass
+
 @app.route('/get_train_data', methods=['POST'])
 def get_train_data():
     try: