diff --git a/desc/inference/inference.api b/desc/inference/inference.api index a77c0e98..a868a6da 100644 --- a/desc/inference/inference.api +++ b/desc/inference/inference.api @@ -20,7 +20,7 @@ type ( TaskDesc string `form:"taskDesc"` ModelName string `form:"modelName"` ModelType string `form:"modelType"` - AdapterId string `form:"adapterId"` + AdapterIds []string `form:"adapterIds"` AiClusterIds []string `form:"aiClusterIds,optional"` ResourceType string `form:"resourceType,optional"` ComputeCard string `form:"card,optional"` @@ -76,6 +76,18 @@ type ( } + /******************TextToImage inference*************************/ + TextToImageInferenceReq{ + TaskName string `form:"taskName"` + TaskDesc string `form:"taskDesc"` + ModelName string `form:"modelName"` + ModelType string `form:"modelType"` + AiClusterIds []string `form:"aiClusterIds"` + } + TextToImageInferenceResp{ + Result []byte + } + /******************Deploy instance*************************/ DeployInstanceListReq{ PageInfo @@ -146,6 +158,7 @@ type ( } GetRunningInstanceReq { + AdapterIds []string `form:"adapterIds"` ModelType string `path:"modelType"` ModelName string `path:"modelName"` } diff --git a/internal/handler/inference/getrunninginstancebymodelhandler.go b/internal/handler/inference/getrunninginstancebymodelhandler.go index 534796c9..8075a73e 100644 --- a/internal/handler/inference/getrunninginstancebymodelhandler.go +++ b/internal/handler/inference/getrunninginstancebymodelhandler.go @@ -1,28 +1,25 @@ package inference import ( + "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/repository/result" "net/http" "github.com/zeromicro/go-zero/rest/httpx" - "gitlink.org.cn/tzwang/pcm-coordinator/internal/logic/inference" - "gitlink.org.cn/tzwang/pcm-coordinator/internal/svc" - "gitlink.org.cn/tzwang/pcm-coordinator/internal/types" + "gitlink.org.cn/JointCloud/pcm-coordinator/internal/logic/inference" + "gitlink.org.cn/JointCloud/pcm-coordinator/internal/svc" + "gitlink.org.cn/JointCloud/pcm-coordinator/internal/types" ) func GetRunningInstanceByModelHandler(svcCtx *svc.ServiceContext) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var req types.GetRunningInstanceReq if err := httpx.Parse(r, &req); err != nil { - httpx.ErrorCtx(r.Context(), w, err) + result.ParamErrorResult(r, w, err) return } l := inference.NewGetRunningInstanceByModelLogic(r.Context(), svcCtx) resp, err := l.GetRunningInstanceByModel(&req) - if err != nil { - httpx.ErrorCtx(r.Context(), w, err) - } else { - httpx.OkJsonCtx(r.Context(), w, resp) - } + result.HttpResult(r, w, resp, err) } } diff --git a/internal/logic/inference/deployinstancelistlogic.go b/internal/logic/inference/deployinstancelistlogic.go index cf937781..4e4c7c37 100644 --- a/internal/logic/inference/deployinstancelistlogic.go +++ b/internal/logic/inference/deployinstancelistlogic.go @@ -4,6 +4,7 @@ import ( "context" "errors" "github.com/zeromicro/go-zero/core/logx" + "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/common" "gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/updater" "gitlink.org.cn/JointCloud/pcm-coordinator/internal/svc" "gitlink.org.cn/JointCloud/pcm-coordinator/internal/types" @@ -30,14 +31,34 @@ func (l *DeployInstanceListLogic) DeployInstanceList(req *types.DeployInstanceLi offset := req.PageSize * (req.PageNum - 1) resp = &types.DeployInstanceListResp{} - var list []*models.AiInferDeployInstance - - tx := l.svcCtx.DbEngin.Raw("select * from ai_infer_deploy_instance").Scan(&list) + var tasklist []*models.AiDeployInstanceTask + tx := l.svcCtx.DbEngin.Raw("select * from ai_deploy_instance_task").Scan(&tasklist) if tx.Error != nil { logx.Errorf(tx.Error.Error()) return nil, tx.Error } + //count total + var total int64 + err = tx.Count(&total).Error + tx.Limit(limit).Offset(offset) + + if err != nil { + return resp, err + } + + err = tx.Order("create_time desc").Find(&tasklist).Error + if err != nil { + return nil, errors.New(err.Error()) + } + + deployTasks := l.GenerateDeployTasks(tasklist) + slices := make([][]*models.AiInferDeployInstance, len(deployTasks)) + for i := 0; i < len(deployTasks); i++ { + slices[i] = deployTasks[i].Instances + } + list := common.ConcatMultipleSlices(slices) + if len(list) == 0 { return } @@ -55,23 +76,35 @@ func (l *DeployInstanceListLogic) DeployInstanceList(req *types.DeployInstanceLi go updater.UpdateDeployInstanceStatus(l.svcCtx, ins, true) go updater.UpdateDeployTaskStatus(l.svcCtx) - //count total - var total int64 - err = tx.Count(&total).Error - tx.Limit(limit).Offset(offset) - if err != nil { - return resp, err - } - - err = tx.Order("create_time desc").Find(&list).Error - if err != nil { - return nil, errors.New(err.Error()) - } - resp.List = &list + resp.List = &deployTasks resp.PageSize = req.PageSize resp.PageNum = req.PageNum resp.Total = total return } + +func (l *DeployInstanceListLogic) GenerateDeployTasks(tasklist []*models.AiDeployInstanceTask) []*DeployTask { + var tasks []*DeployTask + for _, t := range tasklist { + list, err := l.svcCtx.Scheduler.AiStorages.GetInstanceListByDeployTaskId(t.Id) + if err != nil { + logx.Errorf("db GetInstanceListByDeployTaskId error") + continue + } + deployTask := &DeployTask{ + Id: t.Id, + Name: t.Name, + Instances: list, + } + tasks = append(tasks, deployTask) + } + return tasks +} + +type DeployTask struct { + Id int64 `json:"id,string"` + Name string `json:"name,string"` + Instances []*models.AiInferDeployInstance `json:"instances,string"` +} diff --git a/internal/logic/inference/getrunninginstancebymodellogic.go b/internal/logic/inference/getrunninginstancebymodellogic.go index 7ffef4c7..1e598dad 100644 --- a/internal/logic/inference/getrunninginstancebymodellogic.go +++ b/internal/logic/inference/getrunninginstancebymodellogic.go @@ -3,8 +3,8 @@ package inference import ( "context" - "gitlink.org.cn/tzwang/pcm-coordinator/internal/svc" - "gitlink.org.cn/tzwang/pcm-coordinator/internal/types" + "gitlink.org.cn/JointCloud/pcm-coordinator/internal/svc" + "gitlink.org.cn/JointCloud/pcm-coordinator/internal/types" "github.com/zeromicro/go-zero/core/logx" ) @@ -24,7 +24,7 @@ func NewGetRunningInstanceByModelLogic(ctx context.Context, svcCtx *svc.ServiceC } func (l *GetRunningInstanceByModelLogic) GetRunningInstanceByModel(req *types.GetRunningInstanceReq) (resp *types.GetRunningInstanceResp, err error) { - // todo: add your logic here and delete this line + resp = &types.GetRunningInstanceResp{} return } diff --git a/internal/scheduler/common/common.go b/internal/scheduler/common/common.go index 68bfaa32..3447b5f5 100644 --- a/internal/scheduler/common/common.go +++ b/internal/scheduler/common/common.go @@ -97,3 +97,21 @@ func Contains(s []string, e string) bool { } return false } + +func ConcatMultipleSlices[T any](slices [][]T) []T { + var totalLen int + + for _, s := range slices { + totalLen += len(s) + } + + result := make([]T, totalLen) + + var i int + + for _, s := range slices { + i += copy(result[i:], s) + } + + return result +} diff --git a/internal/scheduler/database/aiStorage.go b/internal/scheduler/database/aiStorage.go index 70ac99c2..58fbc554 100644 --- a/internal/scheduler/database/aiStorage.go +++ b/internal/scheduler/database/aiStorage.go @@ -485,6 +485,16 @@ func (s *AiStorage) GetInferDeployInstanceList() ([]*models.AiInferDeployInstanc return list, nil } +func (s *AiStorage) GetDeployTaskList() ([]*models.AiDeployInstanceTask, error) { + var list []*models.AiDeployInstanceTask + tx := s.DbEngin.Raw("select * from ai_deploy_instance_task").Scan(&list) + if tx.Error != nil { + logx.Errorf(tx.Error.Error()) + return nil, tx.Error + } + return list, nil +} + func (s *AiStorage) GetInferDeployInstanceTotalNum() (int32, error) { var total int32 tx := s.DbEngin.Raw("select count(*) from ai_infer_deploy_instance").Scan(&total) @@ -563,3 +573,13 @@ func (s *AiStorage) SaveInferDeployTask(taskName string, modelName string, model } return taskModel.Id, nil } + +func (s *AiStorage) GetRunningDeployInstanceByModelNameAndAdapterId(modelType string, modelName string, adapterId string) ([]*models.AiInferDeployInstance, error) { + var list []*models.AiInferDeployInstance + tx := s.DbEngin.Raw("select * from ai_infer_deploy_instance where `model_type` = ? and `model_name` = ? and `adapter_id` = ? and `status` = 'Running'", modelType, modelName, adapterId).Scan(&list) + if tx.Error != nil { + logx.Errorf(tx.Error.Error()) + return nil, tx.Error + } + return list, nil +} diff --git a/internal/storeLink/storeLink.go b/internal/storeLink/storeLink.go index a59591ca..8c7dc042 100644 --- a/internal/storeLink/storeLink.go +++ b/internal/storeLink/storeLink.go @@ -82,6 +82,7 @@ var ( "image_classification": {"imagenet_resnet50"}, "text_to_text": {"chatGLM_6B"}, "image_to_text": {"blip-image-captioning-base"}, + "text_to_image": {"stable-diffusion-xl-base-1.0"}, } AITYPE = map[string]string{ "1": OCTOPUS,