Merge pull request 'updated imageinfer api' (#229) from tzwang/pcm-coordinator:master into master

Former-commit-id: a0c5f2c65910568b300c70add2c67c0a88d4ff61
This commit is contained in:
tzwang 2024-06-21 19:31:13 +08:00
commit f600a0c98d
27 changed files with 359 additions and 49 deletions

View File

@ -2,6 +2,18 @@ syntax = "v1"
type ( type (
/******************image inference*************************/ /******************image inference*************************/
ModelTypesResp {
ModelTypes []string `json:"types"`
}
ModelNamesReq {
Type string `form:"type"`
}
ModelNamesResp {
ModelNames []string `json:"models"`
}
/******************image inference*************************/
ImageInferenceReq { ImageInferenceReq {
TaskName string `form:"taskName"` TaskName string `form:"taskName"`
@ -31,4 +43,5 @@ type (
Card string `json:"card"` Card string `json:"card"`
ImageResult string `json:"imageResult"` ImageResult string `json:"imageResult"`
} }
) )

View File

@ -909,6 +909,12 @@ service pcm {
service pcm { service pcm {
@handler ImageInferenceHandler @handler ImageInferenceHandler
post /inference/images (ImageInferenceReq) returns (ImageInferenceResp) post /inference/images (ImageInferenceReq) returns (ImageInferenceResp)
@handler ModelTypesHandler
get /inference/modelTypes returns (ModelTypesResp)
@handler ModelNamesByTypeHandler
get /inference/modelNames (ModelNamesReq) returns (ModelNamesResp)
} }
@server( @server(

View File

@ -1,6 +1,7 @@
Name: pcm.core.api Name: pcm.core.api
Host: 0.0.0.0 Host: 0.0.0.0
Port: 8999 Port: 8999
MaxBytes: 524288000
Timeout: 50000 Timeout: 50000

View File

@ -90,7 +90,7 @@ func UpdateAiTaskStatus(svc *svc.ServiceContext, tasklist []*types.TaskModel) {
var wg sync.WaitGroup var wg sync.WaitGroup
for _, aitask := range aiTaskList { for _, aitask := range aiTaskList {
t := aitask t := aitask
if t.Status == constants.Completed || t.Status == constants.Failed { if t.Status == constants.Completed || t.Status == constants.Failed || t.JobId == "" {
continue continue
} }
wg.Add(1) wg.Add(1)

View File

@ -0,0 +1,25 @@
package inference
import (
"net/http"
"github.com/zeromicro/go-zero/rest/httpx"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/logic/inference"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/repository/result"
)
func ModelNamesByTypeHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var req types.ModelNamesReq
if err := httpx.Parse(r, &req); err != nil {
result.ParamErrorResult(r, w, err)
return
}
l := inference.NewModelNamesByTypeLogic(r.Context(), svcCtx)
resp, err := l.ModelNamesByType(&req)
result.HttpResult(r, w, resp, err)
}
}

View File

@ -0,0 +1,16 @@
package inference
import (
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/logic/inference"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/repository/result"
"net/http"
)
func ModelTypesHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
l := inference.NewModelTypesLogic(r.Context(), svcCtx)
resp, err := l.ModelTypes()
result.HttpResult(r, w, resp, err)
}
}

View File

@ -1143,6 +1143,16 @@ func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) {
Path: "/inference/images", Path: "/inference/images",
Handler: inference.ImageInferenceHandler(serverCtx), Handler: inference.ImageInferenceHandler(serverCtx),
}, },
{
Method: http.MethodGet,
Path: "/inference/modelTypes",
Handler: inference.ModelTypesHandler(serverCtx),
},
{
Method: http.MethodGet,
Path: "/inference/modelNames",
Handler: inference.ModelNamesByTypeHandler(serverCtx),
},
}, },
rest.WithPrefix("/pcm/v1"), rest.WithPrefix("/pcm/v1"),
) )

View File

@ -94,7 +94,7 @@ func (l *CommitGeneralTaskLogic) CommitGeneralTask(req *types.GeneralTaskReq) er
Name: req.Name, Name: req.Name,
CommitTime: time.Now(), CommitTime: time.Now(),
YamlString: strings.Join(req.ReqBody, "\n---\n"), YamlString: strings.Join(req.ReqBody, "\n---\n"),
AdapterTypeDict: 0, AdapterTypeDict: "0",
SynergyStatus: synergyStatus, SynergyStatus: synergyStatus,
Strategy: strategy, Strategy: strategy,
} }

View File

@ -86,7 +86,7 @@ func (l *CommitVmTaskLogic) CommitVmTask(req *types.CommitVmTaskReq) (resp *type
Name: req.Name, Name: req.Name,
CommitTime: time.Now(), CommitTime: time.Now(),
Description: "vm task", Description: "vm task",
AdapterTypeDict: 0, AdapterTypeDict: "0",
SynergyStatus: synergyStatus, SynergyStatus: synergyStatus,
Strategy: strategy, Strategy: strategy,
} }

View File

@ -263,7 +263,7 @@ func (l *PageListTaskLogic) updateAiTaskStatus(tasklist []*types.TaskModel, ch c
var wg sync.WaitGroup var wg sync.WaitGroup
for _, aitask := range aiTaskList { for _, aitask := range aiTaskList {
t := aitask t := aitask
if t.Status == constants.Completed || t.Status == constants.Failed { if t.Status == constants.Completed || t.Status == constants.Failed || t.JobId == "" {
continue continue
} }
wg.Add(1) wg.Add(1)

View File

@ -36,14 +36,14 @@ func (l *TaskDetailsLogic) TaskDetails(req *types.FId) (resp *types.TaskDetailsR
var cList []*types.ClusterInfo var cList []*types.ClusterInfo
var subList []*types.SubTaskInfo var subList []*types.SubTaskInfo
switch task.AdapterTypeDict { switch task.AdapterTypeDict {
case 0: case "0":
l.svcCtx.DbEngin.Table("task_cloud").Where("task_id", task.Id).Scan(&subList) l.svcCtx.DbEngin.Table("task_cloud").Where("task_id", task.Id).Scan(&subList)
if len(subList) <= 0 { if len(subList) <= 0 {
l.svcCtx.DbEngin.Table("task_vm").Where("task_id", task.Id).Find(&subList) l.svcCtx.DbEngin.Table("task_vm").Where("task_id", task.Id).Find(&subList)
} }
case 1: case "1":
l.svcCtx.DbEngin.Table("task_ai").Where("task_id", task.Id).Scan(&subList) l.svcCtx.DbEngin.Table("task_ai").Where("task_id", task.Id).Scan(&subList)
case 2: case "2":
l.svcCtx.DbEngin.Table("task_hpc").Where("task_id", task.Id).Scan(&subList) l.svcCtx.DbEngin.Table("task_hpc").Where("task_id", task.Id).Scan(&subList)
} }
for _, sub := range subList { for _, sub := range subList {

View File

@ -122,7 +122,7 @@ func (l *TaskListLogic) TaskList(req *types.TaskListReq) (resp *types.TaskListRe
func (l *TaskListLogic) updateAitaskStatus(tasks []models.Task, ch chan<- struct{}) { func (l *TaskListLogic) updateAitaskStatus(tasks []models.Task, ch chan<- struct{}) {
for _, task := range tasks { for _, task := range tasks {
if task.AdapterTypeDict != 1 { if task.AdapterTypeDict != "1" {
continue continue
} }
if task.Status == constants.Succeeded { if task.Status == constants.Succeeded {

View File

@ -40,7 +40,7 @@ func (l *CommitHpcTaskLogic) CommitHpcTask(req *types.CommitHpcTaskReq) (resp *t
Strategy: 0, Strategy: 0,
SynergyStatus: 0, SynergyStatus: 0,
CommitTime: time.Now(), CommitTime: time.Now(),
AdapterTypeDict: 2, AdapterTypeDict: "2",
} }
// 保存任务数据到数据库 // 保存任务数据到数据库

View File

@ -10,10 +10,15 @@ import (
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/strategy" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/strategy"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
"math/rand" "math/rand"
"mime/multipart" "mime/multipart"
"net/http" "net/http"
"sort"
"strconv"
"sync" "sync"
"time"
) )
type ImageInferenceLogic struct { type ImageInferenceLogic struct {
@ -88,7 +93,6 @@ func (l *ImageInferenceLogic) ImageInfer(r *http.Request, req *types.ImageInfere
var strat strategy.Strategy var strat strategy.Strategy
switch opt.Strategy { switch opt.Strategy {
case strategy.STATIC_WEIGHT: case strategy.STATIC_WEIGHT:
//todo resources should match cluster StaticWeightMap
strat = strategy.NewStaticWeightStrategy(opt.StaticWeightMap, int32(len(ts))) strat = strategy.NewStaticWeightStrategy(opt.StaticWeightMap, int32(len(ts)))
if err != nil { if err != nil {
return nil, err return nil, err
@ -128,33 +132,69 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s
var wg sync.WaitGroup var wg sync.WaitGroup
var cluster_ch = make(chan struct { var cluster_ch = make(chan struct {
urls []*collector.ImageInferUrl urls []*collector.ImageInferUrl
clusterId string
clusterName string clusterName string
imageNum int32 imageNum int32
}, len(clusters)) }, len(clusters))
var cs []struct { var cs []struct {
urls []*collector.ImageInferUrl urls []*collector.ImageInferUrl
clusterId string
clusterName string clusterName string
imageNum int32 imageNum int32
} }
collectorMap := svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId] collectorMap := svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId]
//save task
var synergystatus int64
if len(clusters) > 1 {
synergystatus = 1
}
strategyCode, err := svcCtx.Scheduler.AiStorages.GetStrategyCode(opt.Strategy)
if err != nil {
return nil, err
}
adapterName, err := svcCtx.Scheduler.AiStorages.GetAdapterNameById(opt.AdapterId)
if err != nil {
return nil, err
}
id, err := svcCtx.Scheduler.AiStorages.SaveTask(opt.TaskName, strategyCode, synergystatus, "11")
if err != nil {
return nil, err
}
svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "create", "任务创建中")
//save taskai
for _, c := range clusters {
clusterName, _ := svcCtx.Scheduler.AiStorages.GetClusterNameById(c.ClusterId)
opt.Replica = c.Replicas
err := svcCtx.Scheduler.AiStorages.SaveAiTask(id, opt, adapterName, c.ClusterId, clusterName, "", constants.Saved, "")
if err != nil {
return nil, err
}
}
for _, cluster := range clusters { for _, cluster := range clusters {
wg.Add(1) wg.Add(1)
c := cluster c := cluster
go func() { go func() {
imageUrls, err := collectorMap[c.ClusterId].GetImageInferUrl(ctx, opt) imageUrls, err := collectorMap[c.ClusterId].GetImageInferUrl(ctx, opt)
if err != nil { if err != nil {
wg.Done()
return return
} }
clusterName, _ := svcCtx.Scheduler.AiStorages.GetClusterNameById(c.ClusterId) clusterName, _ := svcCtx.Scheduler.AiStorages.GetClusterNameById(c.ClusterId)
s := struct { s := struct {
urls []*collector.ImageInferUrl urls []*collector.ImageInferUrl
clusterId string
clusterName string clusterName string
imageNum int32 imageNum int32
}{ }{
urls: imageUrls, urls: imageUrls,
clusterId: c.ClusterId,
clusterName: clusterName, clusterName: clusterName,
imageNum: c.Replicas, imageNum: c.Replicas,
} }
@ -171,11 +211,42 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s
cs = append(cs, s) cs = append(cs, s)
} }
var aiTaskList []*models.TaskAi
tx := svcCtx.DbEngin.Raw("select * from task_ai where `task_id` = ? ", id).Scan(&aiTaskList)
if tx.Error != nil {
return nil, tx.Error
}
//change cluster status
if len(clusters) != len(cs) {
var acs []*strategy.AssignedCluster
for _, cluster := range clusters {
if contains(cs, cluster.ClusterId) {
continue
} else {
var ac *strategy.AssignedCluster
ac = cluster
acs = append(acs, ac)
}
}
// update failed cluster status
for _, ac := range acs {
for _, t := range aiTaskList {
if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) {
t.Status = constants.Failed
err := svcCtx.Scheduler.AiStorages.UpdateAiTask(t)
if err != nil {
logx.Errorf(tx.Error.Error())
}
}
}
}
}
var result_ch = make(chan *types.ImageResult, len(ts)) var result_ch = make(chan *types.ImageResult, len(ts))
var results []*types.ImageResult var results []*types.ImageResult
wg.Add(len(ts))
var imageNumIdx int32 = 0 var imageNumIdx int32 = 0
var imageNumIdxEnd int32 = 0 var imageNumIdxEnd int32 = 0
for _, c := range cs { for _, c := range cs {
@ -189,16 +260,32 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s
new_images = new_images[imageNumIdx:imageNumIdxEnd] new_images = new_images[imageNumIdx:imageNumIdxEnd]
imageNumIdx = imageNumIdx + c.imageNum imageNumIdx = imageNumIdx + c.imageNum
wg.Add(len(new_images))
go sendInferReq(new_images, c, &wg, result_ch) go sendInferReq(new_images, c, &wg, result_ch)
} }
wg.Wait() wg.Wait()
close(result_ch) close(result_ch)
for s := range result_ch { for s := range result_ch {
results = append(results, s) results = append(results, s)
} }
sort.Slice(results, func(p, q int) bool {
return results[p].ClusterName < results[q].ClusterName
})
// update succeeded cluster status
for _, c := range cs {
for _, t := range aiTaskList {
if c.clusterId == strconv.Itoa(int(t.ClusterId)) {
t.Status = constants.Completed
err := svcCtx.Scheduler.AiStorages.UpdateAiTask(t)
if err != nil {
logx.Errorf(tx.Error.Error())
}
}
}
}
return results, nil return results, nil
} }
@ -207,6 +294,7 @@ func sendInferReq(images []struct {
file multipart.File file multipart.File
}, cluster struct { }, cluster struct {
urls []*collector.ImageInferUrl urls []*collector.ImageInferUrl
clusterId string
clusterName string clusterName string
imageNum int32 imageNum int32
}, wg *sync.WaitGroup, ch chan<- *types.ImageResult) { }, wg *sync.WaitGroup, ch chan<- *types.ImageResult) {
@ -216,6 +304,7 @@ func sendInferReq(images []struct {
file multipart.File file multipart.File
}, c struct { }, c struct {
urls []*collector.ImageInferUrl urls []*collector.ImageInferUrl
clusterId string
clusterName string clusterName string
imageNum int32 imageNum int32
}) { }) {
@ -223,6 +312,8 @@ func sendInferReq(images []struct {
r, err := getInferResult(c.urls[0].Url, t.file, t.imageResult.ImageName) r, err := getInferResult(c.urls[0].Url, t.file, t.imageResult.ImageName)
if err != nil { if err != nil {
t.imageResult.ImageResult = err.Error() t.imageResult.ImageResult = err.Error()
t.imageResult.ClusterName = c.clusterName
t.imageResult.Card = c.urls[0].Card
ch <- t.imageResult ch <- t.imageResult
wg.Done() wg.Done()
return return
@ -239,6 +330,8 @@ func sendInferReq(images []struct {
r, err := getInferResult(c.urls[idx].Url, t.file, t.imageResult.ImageName) r, err := getInferResult(c.urls[idx].Url, t.file, t.imageResult.ImageName)
if err != nil { if err != nil {
t.imageResult.ImageResult = err.Error() t.imageResult.ImageResult = err.Error()
t.imageResult.ClusterName = c.clusterName
t.imageResult.Card = c.urls[idx].Card
ch <- t.imageResult ch <- t.imageResult
wg.Done() wg.Done()
return return
@ -257,7 +350,7 @@ func sendInferReq(images []struct {
func getInferResult(url string, file multipart.File, fileName string) (string, error) { func getInferResult(url string, file multipart.File, fileName string) (string, error) {
var res Res var res Res
req := GetACHttpRequest() req := GetRestyRequest(10)
_, err := req. _, err := req.
SetFileReader("file", fileName, file). SetFileReader("file", fileName, file).
SetResult(&res). SetResult(&res).
@ -269,8 +362,8 @@ func getInferResult(url string, file multipart.File, fileName string) (string, e
return res.Result, nil return res.Result, nil
} }
func GetACHttpRequest() *resty.Request { func GetRestyRequest(timeoutSeconds int64) *resty.Request {
client := resty.New() client := resty.New().SetTimeout(time.Duration(timeoutSeconds) * time.Second)
request := client.R() request := client.R()
return request return request
} }
@ -278,3 +371,17 @@ func GetACHttpRequest() *resty.Request {
type Res struct { type Res struct {
Result string `json:"result"` Result string `json:"result"`
} }
func contains(cs []struct {
urls []*collector.ImageInferUrl
clusterId string
clusterName string
imageNum int32
}, e string) bool {
for _, c := range cs {
if c.clusterId == e {
return true
}
}
return false
}

View File

@ -0,0 +1,36 @@
package inference
import (
"context"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/storeLink"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types"
"github.com/zeromicro/go-zero/core/logx"
)
type ModelNamesByTypeLogic struct {
logx.Logger
ctx context.Context
svcCtx *svc.ServiceContext
}
func NewModelNamesByTypeLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ModelNamesByTypeLogic {
return &ModelNamesByTypeLogic{
Logger: logx.WithContext(ctx),
ctx: ctx,
svcCtx: svcCtx,
}
}
func (l *ModelNamesByTypeLogic) ModelNamesByType(req *types.ModelNamesReq) (resp *types.ModelNamesResp, err error) {
resp = &types.ModelNamesResp{}
models, err := storeLink.GetModelNamesByType(req.Type)
if err != nil {
logx.Errorf("ModelNamesByType err: %v", err)
return nil, err
}
resp.ModelNames = models
return resp, nil
}

View File

@ -0,0 +1,32 @@
package inference
import (
"context"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/storeLink"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types"
"github.com/zeromicro/go-zero/core/logx"
)
type ModelTypesLogic struct {
logx.Logger
ctx context.Context
svcCtx *svc.ServiceContext
}
func NewModelTypesLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ModelTypesLogic {
return &ModelTypesLogic{
Logger: logx.WithContext(ctx),
ctx: ctx,
svcCtx: svcCtx,
}
}
func (l *ModelTypesLogic) ModelTypes() (resp *types.ModelTypesResp, err error) {
resp = &types.ModelTypesResp{}
mTypes := storeLink.GetModelTypes()
resp.ModelTypes = mTypes
return resp, nil
}

View File

@ -71,7 +71,7 @@ func (l *ScheduleSubmitLogic) ScheduleSubmit(req *types.ScheduleReq) (resp *type
if err != nil { if err != nil {
return nil, err return nil, err
} }
id, err := l.svcCtx.Scheduler.AiStorages.SaveTask(req.AiOption.TaskName, strategyCode, synergystatus) id, err := l.svcCtx.Scheduler.AiStorages.SaveTask(req.AiOption.TaskName, strategyCode, synergystatus, "10")
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -94,7 +94,7 @@ func (s *AiStorage) GetAiTasksByAdapterId(adapterId string) ([]*models.TaskAi, e
return resp, nil return resp, nil
} }
func (s *AiStorage) SaveTask(name string, strategyCode int64, synergyStatus int64) (int64, error) { func (s *AiStorage) SaveTask(name string, strategyCode int64, synergyStatus int64, aiType string) (int64, error) {
// 构建主任务结构体 // 构建主任务结构体
taskModel := models.Task{ taskModel := models.Task{
Status: constants.Saved, Status: constants.Saved,
@ -102,7 +102,8 @@ func (s *AiStorage) SaveTask(name string, strategyCode int64, synergyStatus int6
Name: name, Name: name,
SynergyStatus: synergyStatus, SynergyStatus: synergyStatus,
Strategy: strategyCode, Strategy: strategyCode,
AdapterTypeDict: 1, AdapterTypeDict: "1",
TaskTypeDict: aiType,
CommitTime: time.Now(), CommitTime: time.Now(),
} }
// 保存任务数据到数据库 // 保存任务数据到数据库
@ -113,9 +114,22 @@ func (s *AiStorage) SaveTask(name string, strategyCode int64, synergyStatus int6
return taskModel.Id, nil return taskModel.Id, nil
} }
func (s *AiStorage) SaveAiTask(taskId int64, option *option.AiOption, adapterName string, clusterId string, clusterName string, jobId string, status string, msg string) error { func (s *AiStorage) SaveAiTask(taskId int64, opt option.Option, adapterName string, clusterId string, clusterName string, jobId string, status string, msg string) error {
var aiOpt *option.AiOption
switch (opt).(type) {
case *option.AiOption:
aiOpt = (opt).(*option.AiOption)
case *option.InferOption:
inferOpt := (opt).(*option.InferOption)
aiOpt = &option.AiOption{}
aiOpt.TaskName = inferOpt.TaskName
aiOpt.Replica = inferOpt.Replica
aiOpt.AdapterId = inferOpt.AdapterId
aiOpt.TaskType = inferOpt.ModelType
aiOpt.StrategyName = inferOpt.Strategy
}
// 构建主任务结构体 // 构建主任务结构体
aId, err := strconv.ParseInt(option.AdapterId, 10, 64) aId, err := strconv.ParseInt(aiOpt.AdapterId, 10, 64)
if err != nil { if err != nil {
return err return err
} }
@ -130,14 +144,14 @@ func (s *AiStorage) SaveAiTask(taskId int64, option *option.AiOption, adapterNam
AdapterName: adapterName, AdapterName: adapterName,
ClusterId: cId, ClusterId: cId,
ClusterName: clusterName, ClusterName: clusterName,
Name: option.TaskName, Name: aiOpt.TaskName,
Replica: int64(option.Replica), Replica: int64(aiOpt.Replica),
JobId: jobId, JobId: jobId,
TaskType: option.TaskType, TaskType: aiOpt.TaskType,
Strategy: option.StrategyName, Strategy: aiOpt.StrategyName,
Status: status, Status: status,
Msg: msg, Msg: msg,
Card: option.ComputeCard, Card: aiOpt.ComputeCard,
CommitTime: time.Now(), CommitTime: time.Now(),
} }
// 保存任务数据到数据库 // 保存任务数据到数据库

View File

@ -222,7 +222,7 @@ func (as *AiScheduler) AssignTask(clusters []*strategy.AssignedCluster) (interfa
synergystatus = 1 synergystatus = 1
} }
strategyCode, err := as.AiStorages.GetStrategyCode(as.option.StrategyName) strategyCode, err := as.AiStorages.GetStrategyCode(as.option.StrategyName)
taskId, err := as.AiStorages.SaveTask(as.option.TaskName, strategyCode, synergystatus) taskId, err := as.AiStorages.SaveTask(as.option.TaskName, strategyCode, synergystatus, "10")
if err != nil { if err != nil {
return nil, errors.New("database add failed: " + err.Error()) return nil, errors.New("database add failed: " + err.Error())
} }

View File

@ -16,3 +16,7 @@ type InferOption struct {
Cmd string `json:"cmd,optional"` Cmd string `json:"cmd,optional"`
Replica int32 `json:"replicas,optional"` Replica int32 `json:"replicas,optional"`
} }
func (a InferOption) GetOptionType() string {
return AI_INFER
}

View File

@ -1,6 +1,7 @@
package option package option
const ( const (
AI_INFER = "ai_infer"
AI = "ai" AI = "ai"
CLOUD = "cloud" CLOUD = "cloud"
HPC = "hpc" HPC = "hpc"

View File

@ -385,10 +385,13 @@ func (m *ModelArtsLink) GetImageInferUrl(ctx context.Context, option *option.Inf
Type: option.ModelType, Type: option.ModelType,
Card: "npu", Card: "npu",
} }
urlResp, _ := m.modelArtsRpc.ImageReasoningUrl(ctx, urlReq) urlResp, err := m.modelArtsRpc.ImageReasoningUrl(ctx, urlReq)
if err != nil {
return nil, err
}
imageUrl := &collector.ImageInferUrl{ imageUrl := &collector.ImageInferUrl{
Url: urlResp.Url, Url: urlResp.Url,
Card: option.ComputeCard, Card: "npu",
} }
imageUrls = append(imageUrls, imageUrl) imageUrls = append(imageUrls, imageUrl)

View File

@ -872,21 +872,28 @@ func setResourceIdByCard(option *option.AiOption, specs *octopus.GetResourceSpec
} }
func (o *OctopusLink) GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.ImageInferUrl, error) { func (o *OctopusLink) GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.ImageInferUrl, error) {
req := &octopus.GetNotebookListReq{
Platform: o.platform,
PageIndex: o.pageIndex,
PageSize: o.pageSize,
}
list, err := o.octopusRpc.GetNotebookList(ctx, req)
if err != nil {
return nil, err
}
var imageUrls []*collector.ImageInferUrl var imageUrls []*collector.ImageInferUrl
for _, notebook := range list.Payload.GetNotebooks() {
if strings.Contains(notebook.AlgorithmName, option.ModelName) {
names := strings.Split(notebook.AlgorithmName, UNDERSCORE)
imageUrl := &collector.ImageInferUrl{ imageUrl := &collector.ImageInferUrl{
Url: "http://0.0.0.0:8888/image", Url: DOMAIN + notebook.Tasks[0].Url + FORWARD_SLASH + "image",
Card: "mlu", Card: names[2],
}
imageUrl1 := &collector.ImageInferUrl{
Url: "http://0.0.0.0:8888/image",
Card: "gcu",
}
imageUrl2 := &collector.ImageInferUrl{
Url: "http://0.0.0.0:8888/image",
Card: "biv100",
} }
imageUrls = append(imageUrls, imageUrl) imageUrls = append(imageUrls, imageUrl)
imageUrls = append(imageUrls, imageUrl1) } else {
imageUrls = append(imageUrls, imageUrl2) continue
}
}
return imageUrls, nil return imageUrls, nil
} }

View File

@ -739,7 +739,10 @@ func (s *ShuguangAi) GetImageInferUrl(ctx context.Context, option *option.InferO
Card: "dcu", Card: "dcu",
} }
urlResp, _ := s.aCRpc.GetInferUrl(ctx, urlReq) urlResp, err := s.aCRpc.GetInferUrl(ctx, urlReq)
if err != nil {
return nil, err
}
imageUrl := &collector.ImageInferUrl{ imageUrl := &collector.ImageInferUrl{
Url: urlResp.Url, Url: urlResp.Url,
Card: option.ComputeCard, Card: option.ComputeCard,

View File

@ -76,6 +76,9 @@ var (
3: "制作完成", 3: "制作完成",
4: "制作失败", 4: "制作失败",
} }
ModelTypeMap = map[string][]string{
"image_recognition": {"imagenet_resnet50"},
}
AITYPE = map[string]string{ AITYPE = map[string]string{
"1": OCTOPUS, "1": OCTOPUS,
"2": MODELARTS, "2": MODELARTS,
@ -128,6 +131,22 @@ func GetResourceTypes() []string {
return resourceTypes return resourceTypes
} }
func GetModelTypes() []string {
var mTypes []string
for k, _ := range ModelTypeMap {
mTypes = append(mTypes, k)
}
return mTypes
}
func GetModelNamesByType(t string) ([]string, error) {
_, ok := ModelTypeMap[t]
if !ok {
return nil, errors.New("model type does not exist")
}
return ModelTypeMap[t], nil
}
func GetDatasetsNames(ctx context.Context, collectorMap map[string]collector.AiCollector) ([]string, error) { func GetDatasetsNames(ctx context.Context, collectorMap map[string]collector.AiCollector) ([]string, error) {
var wg sync.WaitGroup var wg sync.WaitGroup
var errCh = make(chan interface{}, len(collectorMap)) var errCh = make(chan interface{}, len(collectorMap))

View File

@ -5879,6 +5879,18 @@ type Category struct {
Name string `json:"name"` Name string `json:"name"`
} }
type ModelTypesResp struct {
ModelTypes []string `json:"types"`
}
type ModelNamesReq struct {
Type string `form:"type"`
}
type ModelNamesResp struct {
ModelNames []string `json:"models"`
}
type ImageInferenceReq struct { type ImageInferenceReq struct {
TaskName string `form:"taskName"` TaskName string `form:"taskName"`
TaskDesc string `form:"taskDesc"` TaskDesc string `form:"taskDesc"`

View File

@ -49,7 +49,8 @@ type (
Result string `db:"result"` // 作业结果 Result string `db:"result"` // 作业结果
DeletedAt gorm.DeletedAt `gorm:"index"` DeletedAt gorm.DeletedAt `gorm:"index"`
NsID string `db:"ns_id"` NsID string `db:"ns_id"`
AdapterTypeDict int `db:"adapter_type_dict"` //任务类型(对应字典表的值) AdapterTypeDict string `db:"adapter_type_dict"` //任务类型(对应字典表的值)
TaskTypeDict string `db:"task_type_dict"`
} }
) )