Merge pull request 'updated gettasklist logic' (#139) from tzwang/pcm-coordinator:master into master

Former-commit-id: 58e1dc340dbbc3229d6092c68dc9b8282f8e21a6
This commit is contained in:
tzwang 2024-05-07 20:09:21 +08:00
commit 70212751f5
6 changed files with 98 additions and 15 deletions

View File

@ -3,6 +3,8 @@ package ai
import (
"context"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
"strconv"
"sync"
"time"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc"
@ -27,12 +29,15 @@ func NewGetCenterTaskListLogic(ctx context.Context, svcCtx *svc.ServiceContext)
func (l *GetCenterTaskListLogic) GetCenterTaskList() (resp *types.CenterTaskListResp, err error) {
resp = &types.CenterTaskListResp{}
var mu sync.RWMutex
adapterList, err := l.svcCtx.Scheduler.AiStorages.GetAdaptersByType("1")
if err != nil {
return nil, err
}
l.updateAiTaskStatus(&mu, adapterList)
for _, adapter := range adapterList {
taskList, err := l.svcCtx.Scheduler.AiStorages.GetAiTasksByAdapterId(adapter.Id)
if err != nil {
@ -46,7 +51,11 @@ func (l *GetCenterTaskListLogic) GetCenterTaskList() (resp *types.CenterTaskList
if err != nil {
elapsed = time.Duration(0)
}
elapsed = end.Sub(task.CommitTime)
start, err := time.ParseInLocation(constants.Layout, task.StartTime, time.Local)
if err != nil {
elapsed = time.Duration(0)
}
elapsed = end.Sub(start)
case constants.Running:
elapsed = time.Now().Sub(task.CommitTime)
default:
@ -64,3 +73,38 @@ func (l *GetCenterTaskListLogic) GetCenterTaskList() (resp *types.CenterTaskList
return resp, nil
}
// updateAiTaskStatus refreshes the stored state of the AI tasks that belong to
// the given adapters.
//
// For every adapter it loads the adapter's tasks from AiStorages, asks the
// collector registered for the task's cluster for the current training-task
// state, copies Status/Start/End onto the task and writes it back with
// UpdateAiTask. Each task is refreshed in its own goroutine; the function
// returns once all of them have finished.
//
// mu is held for reading around the task-list query and for writing around
// each update. The refresh is best-effort: adapters or tasks whose data cannot
// be fetched or stored are skipped without reporting an error.
func (l *GetCenterTaskListLogic) updateAiTaskStatus(mu *sync.RWMutex, list []*types.AdapterInfo) {
	var wg sync.WaitGroup
	for _, adapter := range list {
		mu.RLock()
		taskList, err := l.svcCtx.Scheduler.AiStorages.GetAiTasksByAdapterId(adapter.Id)
		mu.RUnlock()
		if err != nil {
			continue
		}

		// Resolve the adapter's collectors here, in the calling goroutine, so
		// the workers never read the loop variable `adapter` (which is shared
		// between iterations before Go 1.22) after it has advanced.
		collectors, ok := l.svcCtx.Scheduler.AiService.AiCollectorAdapterMap[adapter.Id]
		if !ok {
			continue
		}

		for _, task := range taskList {
			t := task
			// A missing collector would otherwise panic inside the goroutine
			// and take the whole process down.
			c, ok := collectors[strconv.FormatInt(t.ClusterId, 10)]
			if !ok || c == nil {
				continue
			}

			wg.Add(1)
			go func() {
				defer wg.Done()

				trainingTask, err := c.GetTrainingTask(l.ctx, t.JobId)
				if err != nil || trainingTask == nil {
					return
				}
				t.Status = trainingTask.Status
				t.StartTime = trainingTask.Start
				t.EndTime = trainingTask.End

				mu.Lock()
				// Best-effort: a failed update leaves the previously stored
				// state untouched, exactly as before.
				_ = l.svcCtx.Scheduler.AiStorages.UpdateAiTask(t)
				mu.Unlock()
			}()
		}
	}
	wg.Wait()
}

View File

@ -54,11 +54,18 @@ func (l *ScheduleSubmitLogic) ScheduleSubmit(req *types.ScheduleReq) (resp *type
switch opt.GetOptionType() {
case option.AI:
id, err := l.svcCtx.Scheduler.AiStorages.SaveTask(req.AiOption.TaskName)
rs := (results).([]*schedulers.AiResult)
var synergystatus int64
if len(rs) > 1 {
synergystatus = 1
}
strategyCode, err := l.svcCtx.Scheduler.AiStorages.GetStrategyCode(req.AiOption.Strategy)
id, err := l.svcCtx.Scheduler.AiStorages.SaveTask(req.AiOption.TaskName, strategyCode, synergystatus)
if err != nil {
return nil, err
}
rs := (results).([]*schedulers.AiResult)
for _, r := range rs {
scheResult := &types.ScheduleResult{}
scheResult.ClusterId = r.ClusterId

View File

@ -71,12 +71,14 @@ func (s *AiStorage) GetAiTasksByAdapterId(adapterId string) ([]*models.TaskAi, e
return resp, nil
}
func (s *AiStorage) SaveTask(name string) (int64, error) {
func (s *AiStorage) SaveTask(name string, strategyCode int64, synergyStatus int64) (int64, error) {
// 构建主任务结构体
taskModel := models.Task{
Status: constants.Saved,
Description: "ai task",
Name: name,
SynergyStatus: synergyStatus,
Strategy: strategyCode,
CommitTime: time.Now(),
}
// 保存任务数据到数据库
@ -197,10 +199,25 @@ func (s *AiStorage) SaveClusterResources(clusterId string, clusterName string, c
return nil
}
func (s *AiStorage) UpdateAiTask(task models.TaskAi) error {
tx := s.DbEngin.Updates(&task)
// UpdateAiTask passes task to DbEngin.Updates and returns the error recorded
// on the resulting handle, or nil when the call reported none.
func (s *AiStorage) UpdateAiTask(task *models.TaskAi) error {
	return s.DbEngin.Updates(task).Error
}
// GetStrategyCode looks up the item_value stored in t_dict_item for the
// dictionary entry whose item_text equals name, restricted to the dictionary
// with dict_code 'schedule_Strategy'.
//
// It returns the scanned value, or 0 together with the database error when the
// query fails. Previously the error was dropped and nil was always returned.
func (s *AiStorage) GetStrategyCode(name string) (int64, error) {
	var strategy int64
	sqlStr := `select t_dict_item.item_value
from t_dict
left join t_dict_item on t_dict.id = t_dict_item.dict_id
where item_text = ?
and t_dict.dict_code = 'schedule_Strategy'`
	// Query the scheduling strategy.
	if err := s.DbEngin.Raw(sqlStr, name).Scan(&strategy).Error; err != nil {
		return 0, err
	}
	return strategy, nil
}

View File

@ -174,10 +174,16 @@ func (as *AiScheduler) AssignTask(clusters []*strategy.AssignedCluster) (interfa
}
if len(errs) != 0 {
taskId, err := as.AiStorages.SaveTask(as.option.TaskName)
var synergystatus int64
if len(clusters) > 1 {
synergystatus = 1
}
strategyCode, err := as.AiStorages.GetStrategyCode(as.option.StrategyName)
taskId, err := as.AiStorages.SaveTask(as.option.TaskName, strategyCode, synergystatus)
if err != nil {
return nil, errors.New("database add failed: " + err.Error())
}
var errmsg string
for _, err := range errs {
e := (err).(struct {

View File

@ -17,7 +17,6 @@ package storeLink
import (
"context"
"errors"
"fmt"
"gitlink.org.cn/JointCloud/pcm-ac/hpcAC"
"gitlink.org.cn/JointCloud/pcm-ac/hpcacclient"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/common"
@ -475,12 +474,21 @@ func (s *ShuguangAi) GetTrainingTaskLog(ctx context.Context, taskId string, inst
}
// GetTrainingTask queries the task identified by taskId through QueryTask and
// converts the reply into a collector.Task carrying the task's id, start time,
// end time and status.
//
// It returns an error when the query fails, when the reply is not a
// *hpcAC.GetPytorchTaskResp, or when the reply's Code is not "0" (in which
// case the reply's Msg is used as the error text).
func (s *ShuguangAi) GetTrainingTask(ctx context.Context, taskId string) (*collector.Task, error) {
	resp, err := s.QueryTask(ctx, taskId)
	if err != nil {
		return nil, err
	}

	// Checked assertion: an unexpected reply type must surface as an error,
	// not as a panic in whichever goroutine polls the task.
	jobresp, ok := resp.(*hpcAC.GetPytorchTaskResp)
	if !ok || jobresp == nil {
		return nil, errors.New("unexpected response type for pytorch task query")
	}
	if jobresp.Code != "0" {
		return nil, errors.New(jobresp.Msg)
	}

	// NOTE(review): Data is read without a nil check; confirm the service
	// always populates it when Code is "0".
	task := &collector.Task{
		Id:     jobresp.Data.Id,
		Start:  jobresp.Data.StartTime,
		End:    jobresp.Data.EndTime,
		Status: jobresp.Data.Status,
	}
	return task, nil
}
func (s *ShuguangAi) Execute(ctx context.Context, option *option.AiOption) (interface{}, error) {

View File

@ -27,4 +27,5 @@ const (
WaitPause = "WaitPause"
WaitStart = "WaitStart"
Pending = "Pending"
Stopped = "Stopped"
)