fix: fix flake8 check

This commit is contained in:
jiajingbin 2025-03-18 17:59:43 +08:00
parent 10aa645982
commit 181a4bb0fa
1 changed files with 16 additions and 4 deletions

View File

@ -97,6 +97,9 @@ class TaosConfig(PretrainedConfig):
self.Taos_max_position_embeddings = Taos_max_position_embeddings self.Taos_max_position_embeddings = Taos_max_position_embeddings
super().__init__(**kwargs) super().__init__(**kwargs)
class BaseStreamer:
pass
class TaosTSGenerationMixin(GenerationMixin): class TaosTSGenerationMixin(GenerationMixin):
@torch.no_grad() @torch.no_grad()
def generate( def generate(
@ -593,7 +596,7 @@ class TaosForPrediction(TaosPreTrainedModel, TaosTSGenerationMixin):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if revin: if revin:
mean, std = input_ids.mean(dim=-1, keepdim=True), input_ids.std(dim=-1, keepdim=True) mean, std = input_ids.mean(dim=-1, keepdim=True), input_ids.std(dim=-1, keepdim=True)
input_ids = (input_ids - mean) / std input_ids = (input_ids - mean) / std
@ -636,7 +639,7 @@ class TaosForPrediction(TaosPreTrainedModel, TaosTSGenerationMixin):
if output_token_len > max_output_length: if output_token_len > max_output_length:
predictions = predictions[:, :max_output_length] predictions = predictions[:, :max_output_length]
if revin: if revin:
predictions = predictions * std + mean predictions = predictions * std + mean
if not return_dict: if not return_dict:
output = (predictions,) + outputs[1:] output = (predictions,) + outputs[1:]
return (loss) + output if loss is not None else output return (loss) + output if loss is not None else output
@ -718,18 +721,27 @@ def init_model():
Taos_hidden_act="gelu", Taos_hidden_act="gelu",
Taos_output_token_lens=[96], Taos_output_token_lens=[96],
) )
Taos_model = TaosForPrediction(config) Taos_model = TaosForPrediction(config)
model_path = "taos.pth" model_path = "taos.pth"
Taos_model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=True) Taos_model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=True)
Taos_model = Taos_model.to(device) Taos_model = Taos_model.to(device)
#src_data = load_data('src_data/holt-winters_1.txt') #src_data = load_data('src_data/holt-winters_1.txt')
#print(f"src_data:{src_data}") #print(f"src_data:{src_data}")
#seqs = torch.tensor(src_data).unsqueeze(0).float() #seqs = torch.tensor(src_data).unsqueeze(0).float()
#seqs = seqs.to(device) #seqs = seqs.to(device)
print(Taos_model) print(Taos_model)
def train():
pass
def infer():
pass
def data_view():
pass
@app.route('/get_train_data', methods=['POST']) @app.route('/get_train_data', methods=['POST'])
def get_train_data(): def get_train_data():
try: try: