fix: fix flake8 check
This commit is contained in:
parent 10aa645982
commit 181a4bb0fa
@@ -97,6 +97,9 @@ class TaosConfig(PretrainedConfig):
         self.Taos_max_position_embeddings = Taos_max_position_embeddings
         super().__init__(**kwargs)
 
+class BaseStreamer:
+    pass
+
 class TaosTSGenerationMixin(GenerationMixin):
     @torch.no_grad()
     def generate(
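The BaseStreamer stub added here presumably exists so that the streamer name referenced inside TaosTSGenerationMixin.generate resolves under flake8's undefined-name check. For reference, the streamer interface it stands in for in the transformers library looks roughly like the sketch below; the put/end pair is the transformers convention, not something this commit defines.

class BaseStreamer:
    """Minimal streamer interface: receives values as they are generated."""

    def put(self, value):
        # called with each newly generated chunk of output
        raise NotImplementedError()

    def end(self):
        # called once when generation finishes
        raise NotImplementedError()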
@@ -593,7 +596,7 @@ class TaosForPrediction(TaosPreTrainedModel, TaosTSGenerationMixin):
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         if revin:
             mean, std = input_ids.mean(dim=-1, keepdim=True), input_ids.std(dim=-1, keepdim=True)
             input_ids = (input_ids - mean) / std
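The revin branch above applies per-series instance normalization before the inputs reach the backbone; the matching denormalization of the predictions appears in the next hunk. A minimal standalone sketch of the round trip, with illustrative tensor names that are not taken from the model:

import torch

x = torch.randn(4, 96)                        # batch of 4 series, 96 time steps each
mean = x.mean(dim=-1, keepdim=True)           # per-series mean, shape (4, 1)
std = x.std(dim=-1, keepdim=True)             # per-series std, shape (4, 1)

x_norm = (x - mean) / std                     # what the model consumes when revin is set
preds_norm = x_norm[:, -8:]                   # stand-in for model output in normalized space
preds = preds_norm * std + mean               # back to the original scale

Note that the hunk divides by std directly; RevIN-style implementations commonly add a small epsilon to std to guard against constant series.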
@@ -636,7 +639,7 @@ class TaosForPrediction(TaosPreTrainedModel, TaosTSGenerationMixin):
         if output_token_len > max_output_length:
             predictions = predictions[:, :max_output_length]
         if revin:
             predictions = predictions * std + mean
         if not return_dict:
             output = (predictions,) + outputs[1:]
             return (loss) + output if loss is not None else output
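One detail in the non-return_dict path above: (loss) is just loss wrapped in redundant parentheses, not a one-element tuple, so loss + output would try to add a tensor to a tuple whenever a loss is present. The usual transformers-style spelling, shown only as a hedged sketch, packs the loss into a tuple first:

output = (predictions,) + outputs[1:]
return ((loss,) + output) if loss is not None else output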
@@ -718,18 +721,27 @@ def init_model():
         Taos_hidden_act="gelu",
         Taos_output_token_lens=[96],
     )
 
     Taos_model = TaosForPrediction(config)
     model_path = "taos.pth"
     Taos_model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=True)
     Taos_model = Taos_model.to(device)
 
     #src_data = load_data('src_data/holt-winters_1.txt')
     #print(f"src_data:{src_data}")
     #seqs = torch.tensor(src_data).unsqueeze(0).float()
     #seqs = seqs.to(device)
     print(Taos_model)
 
+def train():
+    pass
+
+def infer():
+    pass
+
+def data_view():
+    pass
+
 @app.route('/get_train_data', methods=['POST'])
 def get_train_data():
     try:
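The commented-out lines in init_model() indicate how an input series is meant to be prepared for the model. A hedged end-to-end sketch of that preparation follows; load_data, the config fields shown, and the final generate call are placeholders or assumptions, not confirmed by this diff.

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build and load the model the same way init_model() does.
config = TaosConfig(Taos_hidden_act="gelu", Taos_output_token_lens=[96])  # other fields omitted
model = TaosForPrediction(config)
model.load_state_dict(torch.load("taos.pth", map_location=torch.device("cpu")), strict=True)
model = model.to(device).eval()

# Mirror the commented-out preprocessing: a 1-D series becomes a (1, T) float tensor on the device.
src_data = [float(v % 24) for v in range(192)]  # placeholder series; the source uses load_data(...)
seqs = torch.tensor(src_data).unsqueeze(0).float().to(device)

# Forecasting goes through TaosTSGenerationMixin.generate; its exact signature is not shown in
# this diff, so the call below is an assumption and is left commented out:
# predictions = model.generate(seqs, max_new_tokens=96)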