Former-commit-id: 803dbd4a6580a9de6f37157eb99438e65ec4c9bd
This commit is contained in:
zhangwei 2024-06-25 09:32:37 +08:00
commit ba46c23ecd
2 changed files with 144 additions and 32 deletions

View File

@ -173,6 +173,12 @@ func UpdateTaskStatus(svc *svc.ServiceContext, tasklist []*types.TaskModel) {
}
}
// Update Infer Task Status
if task.TaskTypeDict == 11 {
UpdateInferTaskStatus(svc, task)
return
}
var aiTask []*models.TaskAi
tx := svc.DbEngin.Raw("select * from task_ai where `task_id` = ? ", task.Id).Scan(&aiTask)
if tx.Error != nil {
@ -500,3 +506,104 @@ func UpdateClusterResource(svc *svc.ServiceContext) {
}
wg.Wait()
}
func UpdateInferTaskStatus(svc *svc.ServiceContext, task *types.TaskModel) {
var aiTask []*models.TaskAi
tx := svc.DbEngin.Raw("select * from task_ai where `task_id` = ? ", task.Id).Scan(&aiTask)
if tx.Error != nil {
logx.Errorf(tx.Error.Error())
return
}
if len(aiTask) == 0 {
task.Status = constants.Failed
tx = svc.DbEngin.Model(task).Table("task").Where("deleted_at is null").Updates(task)
if tx.Error != nil {
logx.Errorf(tx.Error.Error())
return
}
return
}
if len(aiTask) == 1 {
if aiTask[0].Status == constants.Completed {
task.StartTime = aiTask[0].StartTime
task.EndTime = aiTask[0].EndTime
task.Status = constants.Succeeded
} else {
task.StartTime = aiTask[0].StartTime
task.Status = aiTask[0].Status
}
task.UpdatedTime = time.Now().Format(constants.Layout)
tx = svc.DbEngin.Model(task).Table("task").Where("deleted_at is null").Updates(task)
if tx.Error != nil {
logx.Errorf(tx.Error.Error())
return
}
return
}
//for i := len(aiTask) - 1; i >= 0; i-- {
// if aiTask[i].StartTime == "" {
// task.Status = aiTask[i].Status
// aiTask = append(aiTask[:i], aiTask[i+1:]...)
// }
//}
//
//if len(aiTask) == 0 {
// task.UpdatedTime = time.Now().Format(constants.Layout)
// tx = svc.DbEngin.Table("task").Model(task).Updates(task)
// if tx.Error != nil {
// logx.Errorf(tx.Error.Error())
// return
// }
// return
//}
start, _ := time.ParseInLocation(time.RFC3339, aiTask[0].StartTime, time.Local)
end, _ := time.ParseInLocation(time.RFC3339, aiTask[0].EndTime, time.Local)
var status string
var count int
for _, a := range aiTask {
if a.Status == constants.Failed {
status = a.Status
break
}
if a.Status == constants.Pending {
status = a.Status
continue
}
if a.Status == constants.Running {
status = a.Status
continue
}
if a.Status == constants.Completed {
count++
continue
}
}
if count == len(aiTask) {
status = constants.Succeeded
}
if status == constants.Succeeded {
task.Status = status
task.StartTime = start.Format(time.RFC3339)
task.EndTime = end.Format(time.RFC3339)
} else {
task.Status = status
task.StartTime = start.Format(time.RFC3339)
}
task.UpdatedTime = time.Now().Format(constants.Layout)
tx = svc.DbEngin.Table("task").Model(task).Updates(task)
if tx.Error != nil {
logx.Errorf(tx.Error.Error())
return
}
}

View File

@ -22,7 +22,6 @@ import (
"math/rand"
"mime/multipart"
"net/http"
"sort"
"strconv"
"sync"
"time"
@ -242,9 +241,12 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s
//change cluster status
if len(clusters) != len(cs) {
var acs []*strategy.AssignedCluster
var rcs []*strategy.AssignedCluster
for _, cluster := range clusters {
if contains(cs, cluster.ClusterId) {
continue
var ac *strategy.AssignedCluster
ac = cluster
rcs = append(rcs, ac)
} else {
var ac *strategy.AssignedCluster
ac = cluster
@ -265,6 +267,29 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s
}
}
}
// update running cluster status
for _, ac := range rcs {
for _, t := range aiTaskList {
if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) {
t.Status = constants.Running
err := svcCtx.Scheduler.AiStorages.UpdateAiTask(t)
if err != nil {
logx.Errorf(tx.Error.Error())
}
}
}
}
svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败")
} else {
for _, t := range aiTaskList {
t.Status = constants.Running
err := svcCtx.Scheduler.AiStorages.UpdateAiTask(t)
if err != nil {
logx.Errorf(tx.Error.Error())
}
}
svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "running", "任务运行中")
}
var result_ch = make(chan *types.ImageResult, len(ts))
@ -294,19 +319,26 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s
results = append(results, s)
}
//sort.Slice(results, func(p, q int) bool {
// return results[p].ClusterName < results[q].ClusterName
//})
//save ai sub tasks
for _, r := range results {
for _, task := range aiTaskList {
if r.ClusterId == strconv.Itoa(int(task.ClusterId)) {
taskAiSub := &models.TaskAiSub{
Id: task.Id,
taskAiSub := models.TaskAiSub{
TaskId: id,
TaskName: task.Name,
TaskAiId: task.TaskId,
TaskAiName: task.Name,
ImageName: r.ImageName,
Result: r.ImageResult,
Card: r.Card,
ClusterId: task.ClusterId,
ClusterName: r.ClusterName,
}
tx := svcCtx.DbEngin.Save(&taskAiSub)
tx := svcCtx.DbEngin.Table("task_ai_sub").Create(&taskAiSub)
if tx.Error != nil {
logx.Errorf(err.Error())
}
@ -314,10 +346,6 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s
}
}
sort.Slice(results, func(p, q int) bool {
return results[p].ClusterName < results[q].ClusterName
})
// update succeeded cluster status
var successStatusCount int
for _, c := range cs {
@ -342,29 +370,6 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s
svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败")
}
//save ai sub tasks
for _, r := range results {
for _, task := range aiTaskList {
if r.ClusterId == strconv.Itoa(int(task.ClusterId)) {
taskAiSub := models.TaskAiSub{
TaskId: id,
TaskName: task.Name,
TaskAiId: task.TaskId,
TaskAiName: task.Name,
ImageName: r.ImageName,
Result: r.ImageResult,
Card: r.Card,
ClusterId: task.ClusterId,
ClusterName: r.ClusterName,
}
tx := svcCtx.DbEngin.Table("task_ai_sub").Create(&taskAiSub)
if tx.Error != nil {
logx.Errorf(err.Error())
}
}
}
}
return results, nil
}