updated gettasklist logic
Former-commit-id: af8ea56c03fc44ba7c881c3bf9c5850765ec2b69
This commit is contained in:
parent
8f381d5030
commit
10e21a9499
|
@ -3,6 +3,8 @@ package ai
|
|||
import (
|
||||
"context"
|
||||
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc"
|
||||
|
@ -27,12 +29,15 @@ func NewGetCenterTaskListLogic(ctx context.Context, svcCtx *svc.ServiceContext)
|
|||
|
||||
func (l *GetCenterTaskListLogic) GetCenterTaskList() (resp *types.CenterTaskListResp, err error) {
|
||||
resp = &types.CenterTaskListResp{}
|
||||
var mu sync.RWMutex
|
||||
|
||||
adapterList, err := l.svcCtx.Scheduler.AiStorages.GetAdaptersByType("1")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
l.updateAiTaskStatus(&mu, adapterList)
|
||||
|
||||
for _, adapter := range adapterList {
|
||||
taskList, err := l.svcCtx.Scheduler.AiStorages.GetAiTasksByAdapterId(adapter.Id)
|
||||
if err != nil {
|
||||
|
@ -46,7 +51,11 @@ func (l *GetCenterTaskListLogic) GetCenterTaskList() (resp *types.CenterTaskList
|
|||
if err != nil {
|
||||
elapsed = time.Duration(0)
|
||||
}
|
||||
elapsed = end.Sub(task.CommitTime)
|
||||
start, err := time.ParseInLocation(constants.Layout, task.StartTime, time.Local)
|
||||
if err != nil {
|
||||
elapsed = time.Duration(0)
|
||||
}
|
||||
elapsed = end.Sub(start)
|
||||
case constants.Running:
|
||||
elapsed = time.Now().Sub(task.CommitTime)
|
||||
default:
|
||||
|
@ -64,3 +73,38 @@ func (l *GetCenterTaskListLogic) GetCenterTaskList() (resp *types.CenterTaskList
|
|||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (l *GetCenterTaskListLogic) updateAiTaskStatus(mu *sync.RWMutex, list []*types.AdapterInfo) {
|
||||
var wg sync.WaitGroup
|
||||
for _, adapter := range list {
|
||||
mu.RLock()
|
||||
taskList, err := l.svcCtx.Scheduler.AiStorages.GetAiTasksByAdapterId(adapter.Id)
|
||||
mu.RUnlock()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, task := range taskList {
|
||||
t := task
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
trainingTask, err := l.svcCtx.Scheduler.AiService.AiCollectorAdapterMap[adapter.Id][strconv.FormatInt(t.ClusterId, 10)].GetTrainingTask(l.ctx, t.JobId)
|
||||
if err != nil {
|
||||
wg.Done()
|
||||
return
|
||||
}
|
||||
t.Status = trainingTask.Status
|
||||
t.StartTime = trainingTask.Start
|
||||
t.EndTime = trainingTask.End
|
||||
mu.Lock()
|
||||
err = l.svcCtx.Scheduler.AiStorages.UpdateAiTask(t)
|
||||
mu.Unlock()
|
||||
if err != nil {
|
||||
wg.Done()
|
||||
return
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
|
|
@ -54,11 +54,18 @@ func (l *ScheduleSubmitLogic) ScheduleSubmit(req *types.ScheduleReq) (resp *type
|
|||
|
||||
switch opt.GetOptionType() {
|
||||
case option.AI:
|
||||
id, err := l.svcCtx.Scheduler.AiStorages.SaveTask(req.AiOption.TaskName)
|
||||
rs := (results).([]*schedulers.AiResult)
|
||||
var synergystatus int64
|
||||
if len(rs) > 1 {
|
||||
synergystatus = 1
|
||||
}
|
||||
strategyCode, err := l.svcCtx.Scheduler.AiStorages.GetStrategyCode(req.AiOption.Strategy)
|
||||
|
||||
id, err := l.svcCtx.Scheduler.AiStorages.SaveTask(req.AiOption.TaskName, strategyCode, synergystatus)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rs := (results).([]*schedulers.AiResult)
|
||||
|
||||
for _, r := range rs {
|
||||
scheResult := &types.ScheduleResult{}
|
||||
scheResult.ClusterId = r.ClusterId
|
||||
|
|
|
@ -71,13 +71,15 @@ func (s *AiStorage) GetAiTasksByAdapterId(adapterId string) ([]*models.TaskAi, e
|
|||
return resp, nil
|
||||
}
|
||||
|
||||
func (s *AiStorage) SaveTask(name string) (int64, error) {
|
||||
func (s *AiStorage) SaveTask(name string, strategyCode int64, synergyStatus int64) (int64, error) {
|
||||
// 构建主任务结构体
|
||||
taskModel := models.Task{
|
||||
Status: constants.Saved,
|
||||
Description: "ai task",
|
||||
Name: name,
|
||||
CommitTime: time.Now(),
|
||||
Status: constants.Saved,
|
||||
Description: "ai task",
|
||||
Name: name,
|
||||
SynergyStatus: synergyStatus,
|
||||
Strategy: strategyCode,
|
||||
CommitTime: time.Now(),
|
||||
}
|
||||
// 保存任务数据到数据库
|
||||
tx := s.DbEngin.Create(&taskModel)
|
||||
|
@ -197,10 +199,25 @@ func (s *AiStorage) SaveClusterResources(clusterId string, clusterName string, c
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *AiStorage) UpdateAiTask(task models.TaskAi) error {
|
||||
tx := s.DbEngin.Updates(&task)
|
||||
func (s *AiStorage) UpdateAiTask(task *models.TaskAi) error {
|
||||
tx := s.DbEngin.Updates(task)
|
||||
if tx.Error != nil {
|
||||
return tx.Error
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AiStorage) GetStrategyCode(name string) (int64, error) {
|
||||
var strategy int64
|
||||
sqlStr := `select t_dict_item.item_value
|
||||
from t_dict
|
||||
left join t_dict_item on t_dict.id = t_dict_item.dict_id
|
||||
where item_text = ?
|
||||
and t_dict.dict_code = 'schedule_Strategy'`
|
||||
//查询调度策略
|
||||
err := s.DbEngin.Raw(sqlStr, name).Scan(&strategy).Error
|
||||
if err != nil {
|
||||
return strategy, nil
|
||||
}
|
||||
return strategy, nil
|
||||
}
|
||||
|
|
|
@ -174,10 +174,16 @@ func (as *AiScheduler) AssignTask(clusters []*strategy.AssignedCluster) (interfa
|
|||
}
|
||||
|
||||
if len(errs) != 0 {
|
||||
taskId, err := as.AiStorages.SaveTask(as.option.TaskName)
|
||||
var synergystatus int64
|
||||
if len(clusters) > 1 {
|
||||
synergystatus = 1
|
||||
}
|
||||
strategyCode, err := as.AiStorages.GetStrategyCode(as.option.StrategyName)
|
||||
taskId, err := as.AiStorages.SaveTask(as.option.TaskName, strategyCode, synergystatus)
|
||||
if err != nil {
|
||||
return nil, errors.New("database add failed: " + err.Error())
|
||||
}
|
||||
|
||||
var errmsg string
|
||||
for _, err := range errs {
|
||||
e := (err).(struct {
|
||||
|
|
|
@ -17,7 +17,6 @@ package storeLink
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"gitlink.org.cn/JointCloud/pcm-ac/hpcAC"
|
||||
"gitlink.org.cn/JointCloud/pcm-ac/hpcacclient"
|
||||
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/common"
|
||||
|
@ -475,12 +474,21 @@ func (s *ShuguangAi) GetTrainingTaskLog(ctx context.Context, taskId string, inst
|
|||
}
|
||||
|
||||
func (s *ShuguangAi) GetTrainingTask(ctx context.Context, taskId string) (*collector.Task, error) {
|
||||
task, err := s.QueryTask(ctx, taskId)
|
||||
resp, err := s.QueryTask(ctx, taskId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Println(task)
|
||||
return nil, nil
|
||||
jobresp := (resp).(*hpcAC.GetPytorchTaskResp)
|
||||
if jobresp.Code != "0" {
|
||||
return nil, errors.New(jobresp.Msg)
|
||||
}
|
||||
var task collector.Task
|
||||
task.Id = jobresp.Data.Id
|
||||
task.Start = jobresp.Data.StartTime
|
||||
task.End = jobresp.Data.EndTime
|
||||
task.Status = jobresp.Data.Status
|
||||
|
||||
return &task, nil
|
||||
}
|
||||
|
||||
func (s *ShuguangAi) Execute(ctx context.Context, option *option.AiOption) (interface{}, error) {
|
||||
|
|
|
@ -27,4 +27,5 @@ const (
|
|||
WaitPause = "WaitPause"
|
||||
WaitStart = "WaitStart"
|
||||
Pending = "Pending"
|
||||
Stopped = "Stopped"
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue