diff --git a/api/internal/cron/aiCronTask.go b/api/internal/cron/aiCronTask.go
index eb295311..3b14cf3c 100644
--- a/api/internal/cron/aiCronTask.go
+++ b/api/internal/cron/aiCronTask.go
@@ -173,6 +173,12 @@ func UpdateTaskStatus(svc *svc.ServiceContext, tasklist []*types.TaskModel) {
 		}
 	}
 
+	// Update Infer Task Status
+	if task.TaskTypeDict == 11 {
+		UpdateInferTaskStatus(svc, task)
+		return
+	}
+
 	var aiTask []*models.TaskAi
 	tx := svc.DbEngin.Raw("select * from task_ai where `task_id` = ? ", task.Id).Scan(&aiTask)
 	if tx.Error != nil {
@@ -500,3 +506,75 @@ func UpdateClusterResource(svc *svc.ServiceContext) {
 	}
 	wg.Wait()
 }
+
+// UpdateInferTaskStatus derives the status of an inference task from its
+// task_ai rows and writes the result back to the task table.
+//
+// A task without task_ai rows is marked failed. With exactly one row the
+// row's status and times are copied, Completed being stored as Succeeded.
+// With several rows the task is failed as soon as one row failed, succeeded
+// once every row completed, and otherwise takes the status of the last
+// pending or running row. Errors are logged and not returned.
+func UpdateInferTaskStatus(svc *svc.ServiceContext, task *types.TaskModel) {
+	var aiTask []*models.TaskAi
+	tx := svc.DbEngin.Raw("select * from task_ai where `task_id` = ? ", task.Id).Scan(&aiTask)
+	if tx.Error != nil {
+		logx.Errorf("query task_ai of task %v: %v", task.Id, tx.Error)
+		return
+	}
+
+	switch len(aiTask) {
+	case 0:
+		// No task_ai rows: mark failed; UpdatedTime is left as it is.
+		task.Status = constants.Failed
+	case 1:
+		task.StartTime = aiTask[0].StartTime
+		task.Status = aiTask[0].Status
+		if aiTask[0].Status == constants.Completed {
+			task.EndTime = aiTask[0].EndTime
+			task.Status = constants.Succeeded
+		}
+		task.UpdatedTime = time.Now().Format(constants.Layout)
+	default:
+		var status string
+		var completed int
+		for _, a := range aiTask {
+			if a.Status == constants.Failed {
+				status = a.Status
+				break
+			}
+			switch a.Status {
+			case constants.Pending, constants.Running:
+				status = a.Status
+			case constants.Completed:
+				completed++
+			}
+		}
+		if completed == len(aiTask) {
+			status = constants.Succeeded
+		}
+
+		// No recognised status among the rows: keep the task's current
+		// status instead of blanking it.
+		if status != "" {
+			task.Status = status
+		}
+
+		// Times are taken from the first row and only when they parse, so
+		// an empty or malformed value never ends up stored as year 0001.
+		if start, err := time.ParseInLocation(time.RFC3339, aiTask[0].StartTime, time.Local); err == nil {
+			task.StartTime = start.Format(time.RFC3339)
+		}
+		if status == constants.Succeeded {
+			if end, err := time.ParseInLocation(time.RFC3339, aiTask[0].EndTime, time.Local); err == nil {
+				task.EndTime = end.Format(time.RFC3339)
+			}
+		}
+		task.UpdatedTime = time.Now().Format(constants.Layout)
+	}
+
+	tx = svc.DbEngin.Model(task).Table("task").Where("deleted_at is null").Updates(task)
+	if tx.Error != nil {
+		logx.Errorf("update status of task %v: %v", task.Id, tx.Error)
+	}
+}
diff --git a/api/internal/logic/inference/imageinferencelogic.go b/api/internal/logic/inference/imageinferencelogic.go
index f52b580d..5f09a84e 100644
--- a/api/internal/logic/inference/imageinferencelogic.go
+++ b/api/internal/logic/inference/imageinferencelogic.go
@@ -22,7 +22,6 @@ import (
 	"math/rand"
 	"mime/multipart"
 	"net/http"
-	"sort"
 	"strconv"
 	"sync"
 	"time"
@@ -242,9 +241,12 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s
 	//change cluster status
 	if len(clusters) != len(cs) {
 		var acs []*strategy.AssignedCluster
+		var rcs []*strategy.AssignedCluster
 		for _, cluster := range clusters {
 			if contains(cs, cluster.ClusterId) {
-				continue
+				var ac *strategy.AssignedCluster
+				ac = cluster
+				rcs = append(rcs, ac)
 			} else {
 				var ac *strategy.AssignedCluster
 				ac = cluster
@@ -265,6 +267,29 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s
 				}
 			}
 		}
+
+		// update running cluster status
+		for _, ac := range rcs {
+			for _, t := range aiTaskList {
+				if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) {
+					t.Status = constants.Running
+					err := svcCtx.Scheduler.AiStorages.UpdateAiTask(t)
+					if err != nil {
+						logx.Errorf("update ai task status: %v", err)
+					}
+				}
+			}
+		}
+		svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败")
+	} else {
+		for _, t := range aiTaskList {
+			t.Status = constants.Running
+			err := svcCtx.Scheduler.AiStorages.UpdateAiTask(t)
+			if err != nil {
+				logx.Errorf("update ai task status: %v", err)
+			}
+		}
+		svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "running", "任务运行中")
 	}
 
 	var result_ch = make(chan *types.ImageResult, len(ts))
@@ -294,19 +319,26 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s
 		results = append(results, s)
 	}
 
+	//sort.Slice(results, func(p, q int) bool {
+	//	return results[p].ClusterName < results[q].ClusterName
+	//})
+
 	//save ai sub tasks
 	for _, r := range results {
 		for _, task := range aiTaskList {
 			if r.ClusterId == strconv.Itoa(int(task.ClusterId)) {
-				taskAiSub := &models.TaskAiSub{
-					Id:          task.Id,
+				taskAiSub := models.TaskAiSub{
+					TaskId:      id,
+					TaskName:    task.Name,
+					TaskAiId:    task.TaskId,
+					TaskAiName:  task.Name,
 					ImageName:   r.ImageName,
 					Result:      r.ImageResult,
 					Card:        r.Card,
 					ClusterId:   task.ClusterId,
 					ClusterName: r.ClusterName,
 				}
-				tx := svcCtx.DbEngin.Save(&taskAiSub)
+				tx := svcCtx.DbEngin.Table("task_ai_sub").Create(&taskAiSub)
 				if tx.Error != nil {
-					logx.Errorf(err.Error())
+					logx.Errorf("save ai sub task: %v", tx.Error)
 				}
@@ -314,10 +346,6 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s
 		}
 	}
 
-	sort.Slice(results, func(p, q int) bool {
-		return results[p].ClusterName < results[q].ClusterName
-	})
-
 	// update succeeded cluster status
 	var successStatusCount int
 	for _, c := range cs {
@@ -342,29 +370,6 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s
 		svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败")
 	}
 
-	//save ai sub tasks
-	for _, r := range results {
-		for _, task := range aiTaskList {
-			if r.ClusterId == strconv.Itoa(int(task.ClusterId)) {
-				taskAiSub := models.TaskAiSub{
-					TaskId:      id,
-					TaskName:    task.Name,
-					TaskAiId:    task.TaskId,
-					TaskAiName:  task.Name,
-					ImageName:   r.ImageName,
-					Result:      r.ImageResult,
-					Card:        r.Card,
-					ClusterId:   task.ClusterId,
-					ClusterName: r.ClusterName,
-				}
-				tx := svcCtx.DbEngin.Table("task_ai_sub").Create(&taskAiSub)
-				if tx.Error != nil {
-					logx.Errorf(err.Error())
-				}
-			}
-		}
-	}
-
 	return results, nil
 }
 