updated inference api logics

Former-commit-id: 64f183b77efd974618fc6306aa78db6dd2d1f893
This commit is contained in:
tzwang 2024-08-27 15:30:26 +08:00
parent 5ec517690a
commit 1103e589f4
6 changed files with 36 additions and 20 deletions

View File

@ -1,6 +1,7 @@
package inference
import (
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/repository/result"
"net/http"
"github.com/zeromicro/go-zero/rest/httpx"
@ -13,16 +14,12 @@ func GetDeployTasksByTypeHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var req types.GetDeployTasksByTypeReq
if err := httpx.Parse(r, &req); err != nil {
httpx.ErrorCtx(r.Context(), w, err)
result.ParamErrorResult(r, w, err)
return
}
l := inference.NewGetDeployTasksByTypeLogic(r.Context(), svcCtx)
resp, err := l.GetDeployTasksByType(&req)
if err != nil {
httpx.ErrorCtx(r.Context(), w, err)
} else {
httpx.OkJsonCtx(r.Context(), w, resp)
}
result.HttpResult(r, w, resp, err)
}
}

View File

@ -1,6 +1,7 @@
package inference
import (
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/repository/result"
"net/http"
"github.com/zeromicro/go-zero/rest/httpx"
@ -13,16 +14,13 @@ func GetRunningInstanceByTypeHandler(svcCtx *svc.ServiceContext) http.HandlerFun
return func(w http.ResponseWriter, r *http.Request) {
var req types.GetRunningInstanceReq
if err := httpx.Parse(r, &req); err != nil {
httpx.ErrorCtx(r.Context(), w, err)
result.ParamErrorResult(r, w, err)
return
}
l := inference.NewGetRunningInstanceByTypeLogic(r.Context(), svcCtx)
resp, err := l.GetRunningInstanceByType(&req)
if err != nil {
httpx.ErrorCtx(r.Context(), w, err)
} else {
httpx.OkJsonCtx(r.Context(), w, resp)
}
result.HttpResult(r, w, resp, err)
}
}

View File

@ -2,6 +2,7 @@ package inference
import (
"context"
"github.com/pkg/errors"
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/svc"
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/types"
@ -24,7 +25,17 @@ func NewGetDeployTasksByTypeLogic(ctx context.Context, svcCtx *svc.ServiceContex
}
func (l *GetDeployTasksByTypeLogic) GetDeployTasksByType(req *types.GetDeployTasksByTypeReq) (resp *types.GetDeployTasksByTypeResp, err error) {
// todo: add your logic here and delete this line
resp = &types.GetDeployTasksByTypeResp{}
return
list, err := l.svcCtx.Scheduler.AiStorages.GetDeployTaskListByType(req.ModelType)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, errors.New("实列不存在")
}
resp.List = list
return resp, nil
}

View File

@ -24,7 +24,7 @@ func NewGetRunningInstanceByTypeLogic(ctx context.Context, svcCtx *svc.ServiceCo
}
func (l *GetRunningInstanceByTypeLogic) GetRunningInstanceByType(req *types.GetRunningInstanceReq) (resp *types.GetRunningInstanceResp, err error) {
// todo: add your logic here and delete this line
resp = &types.GetRunningInstanceResp{}
return
}

View File

@ -35,11 +35,11 @@ func NewImageInferenceLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Im
func (l *ImageInferenceLogic) ImageInfer(r *http.Request, req *types.ImageInferenceReq) (resp *types.ImageInferenceResp, err error) {
resp = &types.ImageInferenceResp{}
opt := &option.InferOption{
TaskName: req.TaskName,
TaskDesc: req.TaskDesc,
AdapterId: req.AdapterId,
AiClusterIds: req.AiClusterIds,
ModelName: req.ModelName,
TaskName: req.TaskName,
TaskDesc: req.TaskDesc,
//AdapterId: req.AdapterId,
//AiClusterIds: req.AiClusterIds,
//ModelName: req.ModelName,
ModelType: req.ModelType,
Strategy: req.Strategy,
StaticWeightMap: req.StaticWeightMap,

View File

@ -431,6 +431,16 @@ func (s *AiStorage) GetDeployTaskById(id int64) (*models.AiDeployInstanceTask, e
return &task, nil
}
func (s *AiStorage) GetDeployTaskListByType(modelType string) ([]*models.AiDeployInstanceTask, error) {
var tasks []*models.AiDeployInstanceTask
tx := s.DbEngin.Raw("select * from ai_deploy_instance_task where `model_type` = ?", modelType).Scan(&tasks)
if tx.Error != nil {
logx.Errorf(tx.Error.Error())
return nil, tx.Error
}
return tasks, nil
}
func (s *AiStorage) GetAllDeployTasks() ([]*models.AiDeployInstanceTask, error) {
var tasks []*models.AiDeployInstanceTask
tx := s.DbEngin.Raw("select * from ai_deploy_instance_task").Scan(&tasks)