diff --git a/api/internal/cron/aiTask.go b/api/internal/cron/aiTask.go
new file mode 100644
index 00000000..bb970c74
--- /dev/null
+++ b/api/internal/cron/aiTask.go
@@ -0,0 +1,348 @@
+package cron
+
+import (
+	"context"
+	"github.com/zeromicro/go-zero/core/logx"
+	"github.com/zeromicro/go-zero/zrpc"
+	"gitlink.org.cn/JointCloud/pcm-ac/hpcacclient"
+	"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/config"
+	"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/collector"
+	"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/executor"
+	"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/storeLink"
+	"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc"
+	"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types"
+	"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
+	"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
+	"gitlink.org.cn/JointCloud/pcm-modelarts/client/imagesservice"
+	"gitlink.org.cn/JointCloud/pcm-modelarts/client/modelartsservice"
+	"gitlink.org.cn/JointCloud/pcm-octopus/octopusclient"
+	"strconv"
+	"sync"
+	"time"
+)
+
+// Cluster names that InitAiClusterMap knows how to build a link for.
+const (
+	OCTOPUS    = "octopus"
+	MODELARTS  = "modelarts"
+	SHUGUANGAI = "shuguangAi"
+)
+
+// GetTaskList returns up to 10 non-deleted rows of the "task" table, newest
+// created_time first.
+func GetTaskList(svc *svc.ServiceContext) ([]*types.TaskModel, error) {
+	const (
+		limit  = 10
+		offset = 0
+	)
+
+	var list []*types.TaskModel
+	// Build the query as one chain so Limit/Offset cannot be lost through a
+	// discarded return value. The former COUNT(*) query was dropped because
+	// its result was never used.
+	err := svc.DbEngin.Model(&types.TaskModel{}).Table("task").
+		Where("deleted_at is null").
+		Order("created_time desc").
+		Limit(limit).
+		Offset(offset).
+		Find(&list).Error
+	if err != nil {
+		return nil, err
+	}
+	return list, nil
+}
+
+// parseTaskTime parses a timestamp read from a task row. It accepts
+// constants.Layout (interpreted in the local zone) as well as RFC 3339 and
+// returns the zero time when the value matches neither.
+func parseTaskTime(value string) time.Time {
+	if t, err := time.ParseInLocation(constants.Layout, value, time.Local); err == nil {
+		return t
+	}
+	if t, err := time.Parse(time.RFC3339, value); err == nil {
+		return t
+	}
+	return time.Time{}
+}
+
+// pickStalestTask returns the entry that has waited longest for a refresh:
+// among the tasks with AdapterTypeDict == 1 whose status is neither Succeeded
+// nor Failed, the one with the oldest UpdatedTime (the first one wins a tie).
+// It returns nil when no entry qualifies and never modifies tasklist.
+func pickStalestTask(tasklist []*types.TaskModel) *types.TaskModel {
+	var (
+		stalest   *types.TaskModel
+		stalestAt time.Time
+	)
+	for _, t := range tasklist {
+		if t == nil || t.AdapterTypeDict != 1 || t.Status == constants.Succeeded || t.Status == constants.Failed {
+			continue
+		}
+		updatedAt := parseTaskTime(t.UpdatedTime)
+		if stalest == nil || updatedAt.Before(stalestAt) {
+			stalest, stalestAt = t, updatedAt
+		}
+	}
+	return stalest
+}
+
+// saveTask persists task to its non-deleted "task" row via Updates and logs
+// (but does not return) any database error.
+func saveTask(svc *svc.ServiceContext, task *types.TaskModel) {
+	tx := svc.DbEngin.Model(task).Table("task").Where("deleted_at is null").Updates(task)
+	if tx.Error != nil {
+		logx.Error(tx.Error)
+	}
+}
+
+// UpdateAiTaskStatus refreshes the task_ai rows of the stalest unfinished task
+// in tasklist. Every sub-task that is not yet Completed or Failed is polled
+// concurrently through its cluster's collector, and the reported status and
+// start/end times are stored via AiStorages.UpdateAiTask. Failures are logged
+// per sub-task; the function returns once all polls have finished.
+//
+// NOTE(review): AiCollectorAdapterMap is read here while UpdateAiAdapterMaps
+// writes it from another cron entry, and no lock is visible in this file.
+func UpdateAiTaskStatus(svc *svc.ServiceContext, tasklist []*types.TaskModel) {
+	task := pickStalestTask(tasklist)
+	if task == nil {
+		return
+	}
+
+	var aiTaskList []*models.TaskAi
+	tx := svc.DbEngin.Raw("select * from task_ai where `task_id` = ? ", task.Id).Scan(&aiTaskList)
+	if tx.Error != nil {
+		logx.Error(tx.Error)
+		return
+	}
+
+	var wg sync.WaitGroup
+	for _, aiTask := range aiTaskList {
+		if aiTask.Status == constants.Completed || aiTask.Status == constants.Failed {
+			continue
+		}
+
+		// Resolve the collector before spawning: calling a method on a missing
+		// (nil) map entry would panic inside the goroutine and take the whole
+		// process down.
+		adapterId := strconv.FormatInt(aiTask.AdapterId, 10)
+		clusterId := strconv.FormatInt(aiTask.ClusterId, 10)
+		col := svc.Scheduler.AiService.AiCollectorAdapterMap[adapterId][clusterId]
+		if col == nil {
+			logx.Errorf("###UpdateAiTaskStatus###, AiTaskId: %v, clusterId: %v , JobId: %v, error: %v \n", aiTask.Id, aiTask.ClusterId, aiTask.JobId, "no collector registered for adapter "+adapterId)
+			continue
+		}
+
+		wg.Add(1)
+		go func(t *models.TaskAi, c collector.AiCollector) {
+			defer wg.Done()
+
+			trainingTask, err := c.GetTrainingTask(context.Background(), t.JobId)
+			if err != nil {
+				logx.Errorf("###UpdateAiTaskStatus###, AiTaskId: %v, clusterId: %v , JobId: %v, error: %v \n", t.Id, t.ClusterId, t.JobId, err.Error())
+				return
+			}
+			if trainingTask == nil {
+				return
+			}
+
+			t.Status = trainingTask.Status
+			t.StartTime = trainingTask.Start
+			t.EndTime = trainingTask.End
+			if err := svc.Scheduler.AiStorages.UpdateAiTask(t); err != nil {
+				logx.Errorf("###UpdateAiTaskStatus###, AiTaskId: %v, clusterId: %v , JobId: %v, error: %v \n", t.Id, t.ClusterId, t.JobId, err.Error())
+			}
+		}(aiTask, col)
+	}
+	wg.Wait()
+}
+
+// UpdateTaskStatus recomputes the status, start time and end time of the
+// stalest unfinished task in tasklist from its task_ai rows and saves the
+// result. The task's UpdatedTime is always stamped with the current time so
+// that the next run moves on to a different task instead of picking the same
+// one again.
+func UpdateTaskStatus(svc *svc.ServiceContext, tasklist []*types.TaskModel) {
+	task := pickStalestTask(tasklist)
+	if task == nil {
+		return
+	}
+
+	var aiTask []*models.TaskAi
+	tx := svc.DbEngin.Raw("select * from task_ai where `task_id` = ? ", task.Id).Scan(&aiTask)
+	if tx.Error != nil {
+		logx.Error(tx.Error)
+		return
+	}
+
+	switch len(aiTask) {
+	case 0:
+		// Nothing to aggregate; only the timestamp below changes.
+	case 1:
+		only := aiTask[0]
+		task.Status = only.Status
+		if only.Status == constants.Completed {
+			task.Status = constants.Succeeded
+		}
+		task.StartTime = only.StartTime
+		task.EndTime = only.EndTime
+	default:
+		aggregateAiTasks(task, aiTask)
+	}
+
+	task.UpdatedTime = time.Now().Format(constants.Layout)
+	saveTask(svc, task)
+}
+
+// aggregateAiTasks folds the state of several sub-tasks into task.
+//
+// Sub-tasks without a StartTime have not started; the first of them lends its
+// status to task. Among the started ones a Failed sub-task wins immediately,
+// otherwise the last Pending/Running status seen is used, and the task only
+// becomes Succeeded when every sub-task has started and is Completed. When a
+// status was derived, StartTime/EndTime are set to the earliest start and
+// latest end that could be parsed.
+func aggregateAiTasks(task *types.TaskModel, aiTasks []*models.TaskAi) {
+	started := make([]*models.TaskAi, 0, len(aiTasks))
+	allStarted := true
+	for _, a := range aiTasks {
+		if a.StartTime != "" {
+			started = append(started, a)
+			continue
+		}
+		if allStarted {
+			task.Status = a.Status
+			allStarted = false
+		}
+	}
+	if len(started) == 0 {
+		return
+	}
+
+	var (
+		start, end time.Time
+		status     string
+		completed  int
+	)
+scan:
+	for _, a := range started {
+		if s := parseTaskTime(a.StartTime); !s.IsZero() && (start.IsZero() || s.Before(start)) {
+			start = s
+		}
+		if e := parseTaskTime(a.EndTime); e.After(end) {
+			end = e
+		}
+
+		switch a.Status {
+		case constants.Failed:
+			status = a.Status
+			break scan
+		case constants.Pending, constants.Running:
+			status = a.Status
+		case constants.Completed:
+			completed++
+		}
+	}
+
+	// A task with sub-tasks that never started is not finished, even if every
+	// started sub-task has completed.
+	if allStarted && completed == len(started) {
+		status = constants.Succeeded
+	}
+	if status == "" {
+		return
+	}
+
+	task.Status = status
+	// Skip unset times: formatting a zero time.Time would store year 0001.
+	if !start.IsZero() {
+		task.StartTime = start.Format(constants.Layout)
+	}
+	if !end.IsZero() {
+		task.EndTime = end.Format(constants.Layout)
+	}
+}
+
+// UpdateAiAdapterMaps registers executor and collector maps for every type "1"
+// adapter that is not yet present in svc.Scheduler.AiService. Adapters without
+// clusters are skipped, and a storage error aborts the run after logging it.
+//
+// NOTE(review): the maps are written without a lock while UpdateAiTaskStatus
+// reads them from another cron entry.
+func UpdateAiAdapterMaps(svc *svc.ServiceContext) {
+	const aiType = "1"
+	adapterIds, err := svc.Scheduler.AiStorages.GetAdapterIdsByType(aiType)
+	if err != nil {
+		logx.Errorf("###UpdateAiAdapterMaps###, error: %v \n", err.Error())
+		return
+	}
+
+	for _, id := range adapterIds {
+		if isAdapterExist(svc, id) {
+			continue
+		}
+		clusters, err := svc.Scheduler.AiStorages.GetClustersByAdapterId(id)
+		if err != nil {
+			logx.Errorf("###UpdateAiAdapterMaps###, error: %v \n", err.Error())
+			return
+		}
+		if len(clusters.List) == 0 {
+			continue
+		}
+		exeClusterMap, colClusterMap := InitAiClusterMap(&svc.Config, clusters.List)
+		svc.Scheduler.AiService.AiExecutorAdapterMap[id] = exeClusterMap
+		svc.Scheduler.AiService.AiCollectorAdapterMap[id] = colClusterMap
+	}
+}
+
+// isAdapterExist reports whether id is already present in both the executor and
+// the collector adapter maps.
+func isAdapterExist(svc *svc.ServiceContext, id string) bool {
+	_, hasExecutor := svc.Scheduler.AiService.AiExecutorAdapterMap[id]
+	_, hasCollector := svc.Scheduler.AiService.AiCollectorAdapterMap[id]
+	return hasExecutor && hasCollector
+}
+
+// InitAiClusterMap builds the executor and collector for each supported cluster
+// (OCTOPUS, MODELARTS, SHUGUANGAI, matched on the cluster name) and returns them
+// keyed by cluster id. Clusters with another name, or with an id that is not a
+// base-10 integer, are skipped.
+//
+// NOTE(review): zrpc.MustNewClient is called for every cluster on every
+// invocation; Must-style constructors usually abort on failure, so confirm that
+// this is acceptable when the function runs from the cron job.
+func InitAiClusterMap(conf *config.Config, clusters []types.ClusterInfo) (map[string]executor.AiExecutor, map[string]collector.AiCollector) {
+	executorMap := make(map[string]executor.AiExecutor, len(clusters))
+	collectorMap := make(map[string]collector.AiCollector, len(clusters))
+	for _, c := range clusters {
+		if c.Name != OCTOPUS && c.Name != MODELARTS && c.Name != SHUGUANGAI {
+			continue
+		}
+		id, err := strconv.ParseInt(c.Id, 10, 64)
+		if err != nil {
+			logx.Errorf("###InitAiClusterMap###, cluster id %q is not numeric, skipped: %v", c.Id, err)
+			continue
+		}
+
+		switch c.Name {
+		case OCTOPUS:
+			octopusRpc := octopusclient.NewOctopus(zrpc.MustNewClient(conf.OctopusRpcConf))
+			octopus := storeLink.NewOctopusLink(octopusRpc, c.Nickname, id)
+			collectorMap[c.Id] = octopus
+			executorMap[c.Id] = octopus
+		case MODELARTS:
+			modelArtsRpc := modelartsservice.NewModelArtsService(zrpc.MustNewClient(conf.ModelArtsRpcConf))
+			modelArtsImgRpc := imagesservice.NewImagesService(zrpc.MustNewClient(conf.ModelArtsImgRpcConf))
+			modelarts := storeLink.NewModelArtsLink(modelArtsRpc, modelArtsImgRpc, c.Name, id, c.Nickname)
+			collectorMap[c.Id] = modelarts
+			executorMap[c.Id] = modelarts
+		case SHUGUANGAI:
+			aCRpc := hpcacclient.NewHpcAC(zrpc.MustNewClient(conf.ACRpcConf))
+			sgai := storeLink.NewShuguangAi(aCRpc, c.Nickname, id)
+			collectorMap[c.Id] = sgai
+			executorMap[c.Id] = sgai
+		}
+	}
+
+	return executorMap, collectorMap
+}
diff --git a/api/internal/cron/cron.go b/api/internal/cron/cron.go
index 7ad0f7b0..be7de1a7 100644
--- a/api/internal/cron/cron.go
+++ b/api/internal/cron/cron.go
@@ -15,6 +15,7 @@
 package cron
 
 import (
+	"github.com/zeromicro/go-zero/core/logx"
 	"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc"
 )
 
@@ -29,4 +30,18 @@ func AddCronGroup(svc *svc.ServiceContext) {
 		SyncParticipantRpc(svc)
 	})
 
+	svc.Cron.AddFunc("*/5 * * * * ?", func() { // refresh task and sub-task status; 6-field spec, presumably every 5 seconds
+		list, err := GetTaskList(svc)
+		if err != nil {
+			logx.Error(err) // not Errorf: the error text must not be interpreted as a format string
+			return
+		}
+		UpdateTaskStatus(svc, list)   // roll the stored task_ai state up into the task row
+		UpdateAiTaskStatus(svc, list) // poll the clusters for fresh task_ai state
+	})
+
+	svc.Cron.AddFunc("*/5 * * * * ?", func() { // register adapters that are not in the in-memory maps yet
+		UpdateAiAdapterMaps(svc)
+	})
+
 }
diff --git a/api/internal/logic/core/pagelisttasklogic.go b/api/internal/logic/core/pagelisttasklogic.go
index 108b3f78..96ccf437 100644
--- a/api/internal/logic/core/pagelisttasklogic.go
+++ b/api/internal/logic/core/pagelisttasklogic.go
@@ -80,6 +80,7 @@ func (l *PageListTaskLogic) PageListTask(req *types.PageTaskReq) (resp *types.Pa
 	for _, ch := range chs {
 		select {
 		case <-ch:
+		case <-time.After(1 * time.Second): // NOTE(review): the 1s bound is per channel, so the loop can still wait len(chs) seconds in total
 		}
 	}
 	return
diff --git a/api/internal/scheduler/service/aiService.go b/api/internal/scheduler/service/aiService.go
index 0567e4a3..3b34bd5a 100644
--- a/api/internal/scheduler/service/aiService.go
+++ b/api/internal/scheduler/service/aiService.go
@@ -13,6 +13,7 @@ import (
 	"gitlink.org.cn/JointCloud/pcm-modelarts/client/modelartsservice"
 	"gitlink.org.cn/JointCloud/pcm-octopus/octopusclient"
 	"strconv"
+	"sync"
 )
 
 const (
@@ -24,6 +25,8 @@ const (
 type AiService struct {
 	AiExecutorAdapterMap  map[string]map[string]executor.AiExecutor
 	AiCollectorAdapterMap map[string]map[string]collector.AiCollector
+	Storage               *database.AiStorage // storage handle passed to NewAiService
+	mu                    sync.Mutex          // NOTE(review): declared but not locked anywhere in the visible hunks
 }
 
 func NewAiService(conf *config.Config, storages *database.AiStorage) (*AiService, error) {
@@ -35,12 +38,16 @@ func NewAiService(conf *config.Config, storages *database.AiStorage) (*AiService
 	aiService := &AiService{
 		AiExecutorAdapterMap:  make(map[string]map[string]executor.AiExecutor),
 		AiCollectorAdapterMap: make(map[string]map[string]collector.AiCollector),
+		Storage:               storages,
 	}
 	for _, id := range adapterIds {
 		clusters, err := storages.GetClustersByAdapterId(id)
 		if err != nil {
 			return nil, err
 		}
+		if len(clusters.List) == 0 { // adapters without clusters get no map entry
+			continue
+		}
 		exeClusterMap, colClusterMap := InitAiClusterMap(conf, clusters.List)
 		aiService.AiExecutorAdapterMap[id] = exeClusterMap
 		aiService.AiCollectorAdapterMap[id] = colClusterMap
@@ -78,3 +85,11 @@ func InitAiClusterMap(conf *config.Config, clusters []types.ClusterInfo) (map[st
 
 	return executorMap, collectorMap
 }
+
+//func (a *AiService) AddCluster() error {
+//
+//}
+//
+//func (a *AiService) AddAdapter() error {
+//
+//}