Merge pull request 'added textinfer api' (#244) from tzwang/pcm-coordinator:master into master

Former-commit-id: 108b1a8c7895a4423f6bd2cf6af9888ba860c944
This commit is contained in:
tzwang 2024-06-25 18:27:38 +08:00
commit 2d8ce51ad6
13 changed files with 216 additions and 19 deletions

View File

@ -63,7 +63,17 @@ type (
clusterName string `json:"clusterName"` clusterName string `json:"clusterName"`
} }
/******************TextToText inference*************************/
// TextToTextInferenceReq is the request payload of the text-to-text
// inference route; every field is bound through its `form` tag.
TextToTextInferenceReq{
// Name and free-text description of the inference task.
TaskName string `form:"taskName"`
TaskDesc string `form:"taskDesc"`
// Model to run and its type/category.
ModelName string `form:"modelName"`
ModelType string `form:"modelType"`
// Adapter the request is addressed to.
AdapterId string `form:"adapterId"`
// Clusters under that adapter to use for the task.
AiClusterIds []string `form:"aiClusterIds"`
}
// TextToTextInferenceResp is the response payload of the text-to-text
// inference route. It declares no fields.
TextToTextInferenceResp{
}
) )

View File

@ -907,6 +907,9 @@ service pcm {
group: inference group: inference
) )
service pcm { service pcm {
@handler TextToTextInferenceHandler
post /inference/text (TextToTextInferenceReq) returns (TextToTextInferenceResp)
@handler ImageInferenceHandler @handler ImageInferenceHandler
post /inference/images (ImageInferenceReq) returns (ImageInferenceResp) post /inference/images (ImageInferenceReq) returns (ImageInferenceResp)

View File

@ -0,0 +1,25 @@
package inference
import (
"github.com/zeromicro/go-zero/rest/httpx"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/logic/inference"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/repository/result"
"net/http"
)
// TextToTextInferenceHandler builds the HTTP handler for the text-to-text
// inference route. The handler decodes the request into a
// types.TextToTextInferenceReq, reports decoding failures through
// result.ParamErrorResult, and otherwise delegates to
// TextToTextInferenceLogic and writes its outcome with result.HttpResult.
func TextToTextInferenceHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		var payload types.TextToTextInferenceReq
		parseErr := httpx.Parse(r, &payload)
		if parseErr != nil {
			// The request could not be bound to the payload struct.
			result.ParamErrorResult(r, w, parseErr)
			return
		}

		logic := inference.NewTextToTextInferenceLogic(r.Context(), svcCtx)
		out, runErr := logic.TextToTextInference(&payload)
		result.HttpResult(r, w, out, runErr)
	}
}

View File

@ -1138,6 +1138,11 @@ func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) {
server.AddRoutes( server.AddRoutes(
[]rest.Route{ []rest.Route{
{
Method: http.MethodPost,
Path: "/inference/text",
Handler: inference.TextToTextInferenceHandler(serverCtx),
},
{ {
Method: http.MethodPost, Method: http.MethodPost,
Path: "/inference/images", Path: "/inference/images",

View File

@ -133,14 +133,14 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s
var wg sync.WaitGroup var wg sync.WaitGroup
var cluster_ch = make(chan struct { var cluster_ch = make(chan struct {
urls []*collector.ImageInferUrl urls []*collector.InferUrl
clusterId string clusterId string
clusterName string clusterName string
imageNum int32 imageNum int32
}, len(clusters)) }, len(clusters))
var cs []struct { var cs []struct {
urls []*collector.ImageInferUrl urls []*collector.InferUrl
clusterId string clusterId string
clusterName string clusterName string
imageNum int32 imageNum int32
@ -182,7 +182,7 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s
wg.Add(1) wg.Add(1)
c := cluster c := cluster
go func() { go func() {
imageUrls, err := collectorMap[c.ClusterId].GetImageInferUrl(ctx, opt) imageUrls, err := collectorMap[c.ClusterId].GetInferUrl(ctx, opt)
if err != nil { if err != nil {
wg.Done() wg.Done()
return return
@ -190,7 +190,7 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s
clusterName, _ := svcCtx.Scheduler.AiStorages.GetClusterNameById(c.ClusterId) clusterName, _ := svcCtx.Scheduler.AiStorages.GetClusterNameById(c.ClusterId)
s := struct { s := struct {
urls []*collector.ImageInferUrl urls []*collector.InferUrl
clusterId string clusterId string
clusterName string clusterName string
imageNum int32 imageNum int32
@ -373,7 +373,7 @@ func sendInferReq(images []struct {
imageResult *types.ImageResult imageResult *types.ImageResult
file multipart.File file multipart.File
}, cluster struct { }, cluster struct {
urls []*collector.ImageInferUrl urls []*collector.InferUrl
clusterId string clusterId string
clusterName string clusterName string
imageNum int32 imageNum int32
@ -384,7 +384,7 @@ func sendInferReq(images []struct {
imageResult *types.ImageResult imageResult *types.ImageResult
file multipart.File file multipart.File
}, c struct { }, c struct {
urls []*collector.ImageInferUrl urls []*collector.InferUrl
clusterId string clusterId string
clusterName string clusterName string
imageNum int32 imageNum int32
@ -494,7 +494,7 @@ type Res struct {
} }
func contains(cs []struct { func contains(cs []struct {
urls []*collector.ImageInferUrl urls []*collector.InferUrl
clusterId string clusterId string
clusterName string clusterName string
imageNum int32 imageNum int32

View File

@ -0,0 +1,140 @@
package inference
import (
"context"
"errors"
"github.com/zeromicro/go-zero/core/logx"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/schedulers/option"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/collector"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/storeLink"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
"strconv"
"sync"
"time"
)
// TextToTextInferenceLogic carries the per-request state used to serve a
// text-to-text inference request: a logger, the request context, and the
// shared service context.
type TextToTextInferenceLogic struct {
// Embedded logger; NewTextToTextInferenceLogic binds it to the request context.
logx.Logger
// ctx is the context handed to the constructor; it is forwarded to collector calls.
ctx context.Context
// svcCtx exposes the scheduler, storages and database engine used by the logic.
svcCtx *svc.ServiceContext
}
// NewTextToTextInferenceLogic creates the logic object for one request.
// The embedded logger is derived from ctx via logx.WithContext, and both ctx
// and svcCtx are stored for later use by TextToTextInference.
func NewTextToTextInferenceLogic(ctx context.Context, svcCtx *svc.ServiceContext) *TextToTextInferenceLogic {
	logic := new(TextToTextInferenceLogic)
	logic.Logger = logx.WithContext(ctx)
	logic.ctx = ctx
	logic.svcCtx = svcCtx
	return logic
}
// TextToTextInference registers a text-to-text inference task and resolves
// the inference endpoint of every requested cluster.
//
// Steps:
//  1. Validate that req.AdapterId has a registered collector map.
//  2. Persist the parent task.
//  3. Concurrently ask each cluster's collector for its inference URLs.
//  4. Persist one AI sub-task per cluster that returned successfully.
//  5. Reload the sub-tasks and, for each one whose cluster produced at least
//     one URL, mark it completed and store "<first url>/chat" as its InferUrl.
//  6. Emit a completion notice.
//
// Clusters without a registered collector, clusters whose collector returns
// an error, and clusters that yield no URL are logged and skipped; they do
// not fail the request. A failure to update an individual sub-task is logged
// and does not fail the request either.
//
// It returns an empty response on success. It returns an error when the
// adapter is unknown or when a storage/database call other than the
// per-sub-task update fails.
func (l *TextToTextInferenceLogic) TextToTextInference(req *types.TextToTextInferenceReq) (resp *types.TextToTextInferenceResp, err error) {
	// clusterInfer carries the inference endpoints resolved for one cluster.
	type clusterInfer struct {
		urls        []*collector.InferUrl
		clusterId   string
		clusterName string
	}

	resp = &types.TextToTextInferenceResp{}

	opt := &option.InferOption{
		TaskName:     req.TaskName,
		TaskDesc:     req.TaskDesc,
		AdapterId:    req.AdapterId,
		AiClusterIds: req.AiClusterIds,
		ModelName:    req.ModelName,
		ModelType:    req.ModelType,
	}

	// Single lookup: the per-cluster collectors of the requested adapter.
	collectorMap, ok := l.svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId]
	if !ok {
		return nil, errors.New("AdapterId does not exist")
	}

	// Save the parent task. Both codes are left at their zero value.
	var synergystatus int64
	var strategyCode int64
	adapterName, err := l.svcCtx.Scheduler.AiStorages.GetAdapterNameById(opt.AdapterId)
	if err != nil {
		return nil, err
	}

	// NOTE(review): "12" is presumably the task-type code for text inference — confirm.
	id, err := l.svcCtx.Scheduler.AiStorages.SaveTask(opt.TaskName, strategyCode, synergystatus, "12")
	if err != nil {
		return nil, err
	}

	// Resolve inference URLs for all clusters concurrently. The channel is
	// buffered to the number of clusters so no sender can block.
	var wg sync.WaitGroup
	clusterCh := make(chan clusterInfer, len(opt.AiClusterIds))
	for _, clusterId := range opt.AiClusterIds {
		aiCollector, found := collectorMap[clusterId]
		if !found {
			// Calling through a missing map entry would panic; skip instead.
			l.Errorf("no collector registered for cluster %s under adapter %s", clusterId, opt.AdapterId)
			continue
		}

		wg.Add(1)
		go func(cId string) {
			defer wg.Done()

			urls, err := aiCollector.GetInferUrl(l.ctx, opt)
			if err != nil {
				l.Errorf("get infer url for cluster %s: %v", cId, err)
				return
			}

			// Best effort: a missing name does not prevent the inference.
			clusterName, err := l.svcCtx.Scheduler.AiStorages.GetClusterNameById(cId)
			if err != nil {
				l.Errorf("get name of cluster %s: %v", cId, err)
			}

			clusterCh <- clusterInfer{
				urls:        urls,
				clusterId:   cId,
				clusterName: clusterName,
			}
		}(clusterId)
	}
	wg.Wait()
	close(clusterCh)

	// Save one AI sub-task per resolved cluster and index the results by
	// cluster id so they can be matched to database rows regardless of order.
	byClusterId := make(map[string]clusterInfer, len(opt.AiClusterIds))
	for c := range clusterCh {
		err := l.svcCtx.Scheduler.AiStorages.SaveAiTask(id, opt, adapterName, c.clusterId, c.clusterName, "", constants.Saved, "")
		if err != nil {
			return nil, err
		}
		byClusterId[c.clusterId] = c
	}

	var aiTaskList []*models.TaskAi
	tx := l.svcCtx.DbEngin.Raw("select * from task_ai where `task_id` = ? ", id).Scan(&aiTaskList)
	if tx.Error != nil {
		return nil, tx.Error
	}

	for _, t := range aiTaskList {
		c, found := byClusterId[strconv.Itoa(int(t.ClusterId))]
		if !found {
			continue
		}
		if len(c.urls) == 0 || c.urls[0] == nil {
			l.Errorf("cluster %s returned no inference url", c.clusterId)
			continue
		}

		t.Status = constants.Completed
		t.EndTime = time.Now().Format(time.RFC3339)
		t.InferUrl = c.urls[0].Url + storeLink.FORWARD_SLASH + "chat"
		if err := l.svcCtx.Scheduler.AiStorages.UpdateAiTask(t); err != nil {
			// Log the update error itself; tx.Error is nil at this point.
			l.Errorf("update ai task for cluster %s: %v", c.clusterId, err)
		}
	}

	l.svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "completed", "任务完成")
	return resp, nil
}

View File

@ -15,10 +15,10 @@ type AiCollector interface {
UploadAlgorithmCode(ctx context.Context, resourceType string, card string, taskType string, dataset string, algorithm string, code string) error UploadAlgorithmCode(ctx context.Context, resourceType string, card string, taskType string, dataset string, algorithm string, code string) error
GetComputeCards(ctx context.Context) ([]string, error) GetComputeCards(ctx context.Context) ([]string, error)
GetUserBalance(ctx context.Context) (float64, error) GetUserBalance(ctx context.Context) (float64, error)
GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*ImageInferUrl, error) GetInferUrl(ctx context.Context, option *option.InferOption) ([]*InferUrl, error)
} }
type ImageInferUrl struct { type InferUrl struct {
Url string Url string
Card string Card string
} }

View File

@ -378,8 +378,8 @@ func (m *ModelArtsLink) generateAlgorithmId(ctx context.Context, option *option.
return errors.New("failed to get AlgorithmId") return errors.New("failed to get AlgorithmId")
} }
func (m *ModelArtsLink) GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.ImageInferUrl, error) { func (m *ModelArtsLink) GetInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.InferUrl, error) {
var imageUrls []*collector.ImageInferUrl var imageUrls []*collector.InferUrl
urlReq := &modelartsclient.ImageReasoningUrlReq{ urlReq := &modelartsclient.ImageReasoningUrlReq{
ModelName: option.ModelName, ModelName: option.ModelName,
Type: option.ModelType, Type: option.ModelType,
@ -389,7 +389,7 @@ func (m *ModelArtsLink) GetImageInferUrl(ctx context.Context, option *option.Inf
if err != nil { if err != nil {
return nil, err return nil, err
} }
imageUrl := &collector.ImageInferUrl{ imageUrl := &collector.InferUrl{
Url: urlResp.Url, Url: urlResp.Url,
Card: "npu", Card: "npu",
} }

View File

@ -871,7 +871,7 @@ func setResourceIdByCard(option *option.AiOption, specs *octopus.GetResourceSpec
return errors.New("set ResourceId error") return errors.New("set ResourceId error")
} }
func (o *OctopusLink) GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.ImageInferUrl, error) { func (o *OctopusLink) GetInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.InferUrl, error) {
req := &octopus.GetNotebookListReq{ req := &octopus.GetNotebookListReq{
Platform: o.platform, Platform: o.platform,
PageIndex: o.pageIndex, PageIndex: o.pageIndex,
@ -882,12 +882,12 @@ func (o *OctopusLink) GetImageInferUrl(ctx context.Context, option *option.Infer
return nil, err return nil, err
} }
var imageUrls []*collector.ImageInferUrl var imageUrls []*collector.InferUrl
for _, notebook := range list.Payload.GetNotebooks() { for _, notebook := range list.Payload.GetNotebooks() {
if strings.Contains(notebook.AlgorithmName, option.ModelName) && notebook.Status == "running" { if strings.Contains(notebook.AlgorithmName, option.ModelName) && notebook.Status == "running" {
url := strings.Replace(notebook.Tasks[0].Url, FORWARD_SLASH, "", -1) url := strings.Replace(notebook.Tasks[0].Url, FORWARD_SLASH, "", -1)
names := strings.Split(notebook.AlgorithmName, UNDERSCORE) names := strings.Split(notebook.AlgorithmName, UNDERSCORE)
imageUrl := &collector.ImageInferUrl{ imageUrl := &collector.InferUrl{
Url: DOMAIN + url + FORWARD_SLASH + "image", Url: DOMAIN + url + FORWARD_SLASH + "image",
Card: names[2], Card: names[2],
} }

View File

@ -730,8 +730,8 @@ func (s *ShuguangAi) generateParams(option *option.AiOption) error {
return errors.New("failed to set params") return errors.New("failed to set params")
} }
func (s *ShuguangAi) GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.ImageInferUrl, error) { func (s *ShuguangAi) GetInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.InferUrl, error) {
var imageUrls []*collector.ImageInferUrl var imageUrls []*collector.InferUrl
urlReq := &hpcAC.GetInferUrlReq{ urlReq := &hpcAC.GetInferUrlReq{
ModelName: option.ModelName, ModelName: option.ModelName,
@ -743,7 +743,7 @@ func (s *ShuguangAi) GetImageInferUrl(ctx context.Context, option *option.InferO
if err != nil { if err != nil {
return nil, err return nil, err
} }
imageUrl := &collector.ImageInferUrl{ imageUrl := &collector.InferUrl{
Url: urlResp.Url, Url: urlResp.Url,
Card: "dcu", Card: "dcu",
} }

View File

@ -78,6 +78,7 @@ var (
} }
ModelTypeMap = map[string][]string{ ModelTypeMap = map[string][]string{
"image_recognition": {"imagenet_resnet50"}, "image_recognition": {"imagenet_resnet50"},
"text_to_text": {"chatGLM-6B"},
} }
AITYPE = map[string]string{ AITYPE = map[string]string{
"1": OCTOPUS, "1": OCTOPUS,

View File

@ -5941,3 +5941,15 @@ type InferenceResult struct {
Card string `json:"card"` Card string `json:"card"`
ClusterName string `json:"clusterName"` ClusterName string `json:"clusterName"`
} }
// TextToTextInferenceReq is the request payload of the text-to-text
// inference endpoint. All fields are bound from the request through their
// `form` tags.
type TextToTextInferenceReq struct {
// TaskName and TaskDesc are copied into the inference option; TaskName is
// also used when the task is saved.
TaskName string `form:"taskName"`
TaskDesc string `form:"taskDesc"`
// ModelName and ModelType are copied into the inference option.
ModelName string `form:"modelName"`
ModelType string `form:"modelType"`
// AdapterId selects the adapter whose collectors serve the request.
AdapterId string `form:"adapterId"`
// AiClusterIds lists the clusters asked for an inference URL.
AiClusterIds []string `form:"aiClusterIds"`
}
// TextToTextInferenceResp is the response payload of the text-to-text
// inference endpoint. It has no fields.
type TextToTextInferenceResp struct {
}

View File

@ -54,6 +54,7 @@ type (
TaskType string `db:"task_type"` TaskType string `db:"task_type"`
DeletedAt *time.Time `db:"deleted_at"` DeletedAt *time.Time `db:"deleted_at"`
Card string `db:"card"` Card string `db:"card"`
InferUrl string `db:"infer_url"`
} }
) )