Merge pull request 'fix bugs' (#166) from tzwang/pcm-coordinator:master into master

Former-commit-id: 3842e14c1e64dd1ca3fddec6f68c6e7ba6cff331
This commit is contained in:
tzwang 2024-05-11 18:44:08 +08:00
commit a0181ba9d6
2 changed files with 100 additions and 10 deletions

View File

@ -4,6 +4,8 @@ import (
"context" "context"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/repository/result" "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/repository/result"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils/timeutils" "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils/timeutils"
"time" "time"
@ -49,6 +51,11 @@ func (l *PageListTaskLogic) PageListTask(req *types.PageTaskReq) (resp *types.Pa
if err != nil { if err != nil {
return nil, result.NewDefaultError(err.Error()) return nil, result.NewDefaultError(err.Error())
} }
// Update the status of AI computing tasks
var ch = make(chan struct{})
go l.updateAitaskStatus(list, ch)
for _, model := range list { for _, model := range list {
if model.StartTime != "" && model.EndTime == "" { if model.StartTime != "" && model.EndTime == "" {
startTime := timeutils.TimeStringToGoTime(model.StartTime) startTime := timeutils.TimeStringToGoTime(model.StartTime)
@ -65,5 +72,64 @@ func (l *PageListTaskLogic) PageListTask(req *types.PageTaskReq) (resp *types.Pa
resp.PageNum = req.PageNum resp.PageNum = req.PageNum
resp.Total = total resp.Total = total
select {
case _ = <-ch:
return resp, nil return resp, nil
case <-time.After(1 * time.Second):
return resp, nil
}
}
func (l *PageListTaskLogic) updateAitaskStatus(tasks []*types.TaskModel, ch chan<- struct{}) {
for _, task := range tasks {
if task.AdapterTypeDict != 1 {
continue
}
if task.Status == constants.Succeeded {
continue
}
var aiTask []*models.TaskAi
tx := l.svcCtx.DbEngin.Raw("select * from task_ai where `task_id` = ? ", task.Id).Scan(&aiTask)
if tx.Error != nil {
logx.Errorf(tx.Error.Error())
return
}
start, _ := time.ParseInLocation(constants.Layout, aiTask[0].StartTime, time.Local)
end, _ := time.ParseInLocation(constants.Layout, aiTask[0].EndTime, time.Local)
var status = constants.Succeeded
for _, a := range aiTask {
s, _ := time.ParseInLocation(constants.Layout, a.StartTime, time.Local)
e, _ := time.ParseInLocation(constants.Layout, a.EndTime, time.Local)
if s.Before(start) {
start = s
}
if e.After(end) {
end = e
}
if a.Status == constants.Failed {
status = a.Status
break
}
if a.Status == constants.Running {
status = a.Status
continue
}
}
task.Status = status
task.StartTime = start.Format(constants.Layout)
task.EndTime = end.Format(constants.Layout)
tx = l.svcCtx.DbEngin.Table("task").Updates(task)
if tx.Error != nil {
return
}
}
ch <- struct{}{}
} }

View File

@ -53,7 +53,7 @@ const (
EnflameT20 = 128 EnflameT20 = 128
BASE_TOPS = 128 BASE_TOPS = 128
CAMBRICON = "cambricon" CAMBRICON = "cambricon"
TIANSHU = "天数" ILUVATAR = "iluvatar"
TRAIN_CMD = "cd /code; python train.py" TRAIN_CMD = "cd /code; python train.py"
VERSION = "V1" VERSION = "V1"
DOMAIN = "http://192.168.242.41:8001/" DOMAIN = "http://192.168.242.41:8001/"
@ -63,7 +63,7 @@ var (
cardAliasMap = map[string]string{ cardAliasMap = map[string]string{
MLU: CAMBRICON, MLU: CAMBRICON,
GCU: ENFLAME, GCU: ENFLAME,
BIV100: TIANSHU, BIV100: ILUVATAR,
} }
cardTopsMap = map[string]float64{ cardTopsMap = map[string]float64{
MLU: CAMBRICONMLU290, MLU: CAMBRICONMLU290,
@ -373,20 +373,37 @@ func (o *OctopusLink) DownloadAlgorithmCode(ctx context.Context, resourceType st
} }
var algorithmId string var algorithmId string
var algorithms []*octopus.Algorithms
for _, a := range resp.Payload.Algorithms { for _, a := range resp.Payload.Algorithms {
if strings.ToLower(a.FrameworkName) != taskType { if strings.ToLower(a.FrameworkName) != taskType {
continue continue
} }
if a.AlgorithmName == name {
algorithmId = a.AlgorithmId if a.AlgorithmDescript == name {
break algorithms = append(algorithms, a)
} }
} }
if algorithmId == "" { if len(algorithms) == 0 {
return "", errors.New("algorithmId not found") return "", errors.New("algorithmId not found")
} }
if len(algorithms) == 1 {
algorithmId = algorithms[0].AlgorithmId
}
aLatest := &octopus.Algorithms{}
for i, _ := range algorithms {
if time.Unix(aLatest.CreatedAt, 0).After(time.Unix(algorithms[i].CreatedAt, 0)) {
aLatest = algorithms[i]
}
}
if aLatest.AlgorithmId == "" {
return "", errors.New("algorithmId not found")
}
algorithmId = aLatest.AlgorithmId
dcReq := &octopus.DownloadCompressReq{ dcReq := &octopus.DownloadCompressReq{
Platform: o.platform, Platform: o.platform,
Version: VERSION, Version: VERSION,
@ -428,6 +445,13 @@ func (o *OctopusLink) DownloadAlgorithmCode(ctx context.Context, resourceType st
} }
func (o *OctopusLink) UploadAlgorithmCode(ctx context.Context, resourceType string, card string, taskType string, dataset string, algorithm string, code string) error { func (o *OctopusLink) UploadAlgorithmCode(ctx context.Context, resourceType string, card string, taskType string, dataset string, algorithm string, code string) error {
//var name string
//if resourceType == CARD {
// name = dataset + UNDERSCORE + algorithm + UNDERSCORE + card
//} else {
// name = dataset + UNDERSCORE + algorithm + UNDERSCORE + CPU
//}
//uploadReq := &octopus.UploadAlgorithmReq{}
return nil return nil
} }
@ -455,9 +479,9 @@ func (o *OctopusLink) GetTrainingTask(ctx context.Context, taskId string) (*coll
if err != nil { if err != nil {
return nil, err return nil, err
} }
jobresp := (resp).(*octopus.GetTrainJobResp) jobresp, ok := (resp).(*octopus.GetTrainJobResp)
if !jobresp.Success { if !jobresp.Success || !ok {
return nil, errors.New(jobresp.Error.Message) return nil, errors.New("get training task failed")
} }
var task collector.Task var task collector.Task
task.Id = jobresp.Payload.TrainJob.Id task.Id = jobresp.Payload.TrainJob.Id