added textinfer api

Former-commit-id: bfdce9025139d71bc3178b039756d579be1b450c
This commit is contained in:
tzwang 2024-06-25 18:19:18 +08:00
parent 60012ab0fb
commit e6b9d3d23b
9 changed files with 140 additions and 31 deletions

View File

@ -1,28 +1,25 @@
package inference package inference
import ( import (
"net/http"
"github.com/zeromicro/go-zero/rest/httpx" "github.com/zeromicro/go-zero/rest/httpx"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/logic/inference" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/logic/inference"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/repository/result"
"net/http"
) )
func TextToTextInferenceHandler(svcCtx *svc.ServiceContext) http.HandlerFunc { func TextToTextInferenceHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
var req types.TextToTextInferenceReq var req types.TextToTextInferenceReq
if err := httpx.Parse(r, &req); err != nil { if err := httpx.Parse(r, &req); err != nil {
httpx.ErrorCtx(r.Context(), w, err) result.ParamErrorResult(r, w, err)
return return
} }
l := inference.NewTextToTextInferenceLogic(r.Context(), svcCtx) l := inference.NewTextToTextInferenceLogic(r.Context(), svcCtx)
resp, err := l.TextToTextInference(&req) resp, err := l.TextToTextInference(&req)
if err != nil { result.HttpResult(r, w, resp, err)
httpx.ErrorCtx(r.Context(), w, err)
} else {
httpx.OkJsonCtx(r.Context(), w, resp)
}
} }
} }

View File

@ -133,14 +133,14 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s
var wg sync.WaitGroup var wg sync.WaitGroup
var cluster_ch = make(chan struct { var cluster_ch = make(chan struct {
urls []*collector.ImageInferUrl urls []*collector.InferUrl
clusterId string clusterId string
clusterName string clusterName string
imageNum int32 imageNum int32
}, len(clusters)) }, len(clusters))
var cs []struct { var cs []struct {
urls []*collector.ImageInferUrl urls []*collector.InferUrl
clusterId string clusterId string
clusterName string clusterName string
imageNum int32 imageNum int32
@ -182,7 +182,7 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s
wg.Add(1) wg.Add(1)
c := cluster c := cluster
go func() { go func() {
imageUrls, err := collectorMap[c.ClusterId].GetImageInferUrl(ctx, opt) imageUrls, err := collectorMap[c.ClusterId].GetInferUrl(ctx, opt)
if err != nil { if err != nil {
wg.Done() wg.Done()
return return
@ -190,7 +190,7 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s
clusterName, _ := svcCtx.Scheduler.AiStorages.GetClusterNameById(c.ClusterId) clusterName, _ := svcCtx.Scheduler.AiStorages.GetClusterNameById(c.ClusterId)
s := struct { s := struct {
urls []*collector.ImageInferUrl urls []*collector.InferUrl
clusterId string clusterId string
clusterName string clusterName string
imageNum int32 imageNum int32
@ -373,7 +373,7 @@ func sendInferReq(images []struct {
imageResult *types.ImageResult imageResult *types.ImageResult
file multipart.File file multipart.File
}, cluster struct { }, cluster struct {
urls []*collector.ImageInferUrl urls []*collector.InferUrl
clusterId string clusterId string
clusterName string clusterName string
imageNum int32 imageNum int32
@ -384,7 +384,7 @@ func sendInferReq(images []struct {
imageResult *types.ImageResult imageResult *types.ImageResult
file multipart.File file multipart.File
}, c struct { }, c struct {
urls []*collector.ImageInferUrl urls []*collector.InferUrl
clusterId string clusterId string
clusterName string clusterName string
imageNum int32 imageNum int32
@ -494,7 +494,7 @@ type Res struct {
} }
func contains(cs []struct { func contains(cs []struct {
urls []*collector.ImageInferUrl urls []*collector.InferUrl
clusterId string clusterId string
clusterName string clusterName string
imageNum int32 imageNum int32

View File

@ -2,11 +2,18 @@ package inference
import ( import (
"context" "context"
"errors"
"github.com/zeromicro/go-zero/core/logx"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/schedulers/option"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/collector"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/storeLink"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
"github.com/zeromicro/go-zero/core/logx" "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
"strconv"
"sync"
"time"
) )
type TextToTextInferenceLogic struct { type TextToTextInferenceLogic struct {
@ -24,7 +31,110 @@ func NewTextToTextInferenceLogic(ctx context.Context, svcCtx *svc.ServiceContext
} }
func (l *TextToTextInferenceLogic) TextToTextInference(req *types.TextToTextInferenceReq) (resp *types.TextToTextInferenceResp, err error) { func (l *TextToTextInferenceLogic) TextToTextInference(req *types.TextToTextInferenceReq) (resp *types.TextToTextInferenceResp, err error) {
// todo: add your logic here and delete this line resp = &types.TextToTextInferenceResp{}
opt := &option.InferOption{
TaskName: req.TaskName,
TaskDesc: req.TaskDesc,
AdapterId: req.AdapterId,
AiClusterIds: req.AiClusterIds,
ModelName: req.ModelName,
ModelType: req.ModelType,
}
_, ok := l.svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId]
if !ok {
return nil, errors.New("AdapterId does not exist")
}
//save task
var synergystatus int64
var strategyCode int64
adapterName, err := l.svcCtx.Scheduler.AiStorages.GetAdapterNameById(opt.AdapterId)
if err != nil {
return nil, err
}
id, err := l.svcCtx.Scheduler.AiStorages.SaveTask(opt.TaskName, strategyCode, synergystatus, "12")
if err != nil {
return nil, err
}
var wg sync.WaitGroup
var cluster_ch = make(chan struct {
urls []*collector.InferUrl
clusterId string
clusterName string
}, len(opt.AiClusterIds))
var cs []struct {
urls []*collector.InferUrl
clusterId string
clusterName string
}
collectorMap := l.svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId]
//save taskai
for _, clusterId := range opt.AiClusterIds {
wg.Add(1)
go func(cId string) {
urls, err := collectorMap[cId].GetInferUrl(l.ctx, opt)
if err != nil {
wg.Done()
return return
} }
clusterName, _ := l.svcCtx.Scheduler.AiStorages.GetClusterNameById(cId)
s := struct {
urls []*collector.InferUrl
clusterId string
clusterName string
}{
urls: urls,
clusterId: cId,
clusterName: clusterName,
}
cluster_ch <- s
wg.Done()
return
}(clusterId)
}
wg.Wait()
close(cluster_ch)
for s := range cluster_ch {
cs = append(cs, s)
}
for _, c := range cs {
clusterName, _ := l.svcCtx.Scheduler.AiStorages.GetClusterNameById(c.clusterId)
err := l.svcCtx.Scheduler.AiStorages.SaveAiTask(id, opt, adapterName, c.clusterId, clusterName, "", constants.Saved, "")
if err != nil {
return nil, err
}
}
var aiTaskList []*models.TaskAi
tx := l.svcCtx.DbEngin.Raw("select * from task_ai where `task_id` = ? ", id).Scan(&aiTaskList)
if tx.Error != nil {
return nil, tx.Error
}
for i, t := range aiTaskList {
if strconv.Itoa(int(t.ClusterId)) == cs[i].clusterId {
t.Status = constants.Completed
t.EndTime = time.Now().Format(time.RFC3339)
url := cs[i].urls[0].Url + storeLink.FORWARD_SLASH + "chat"
t.InferUrl = url
err := l.svcCtx.Scheduler.AiStorages.UpdateAiTask(t)
if err != nil {
logx.Errorf(tx.Error.Error())
}
}
}
l.svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "completed", "任务完成")
return resp, nil
}

View File

@ -15,10 +15,10 @@ type AiCollector interface {
UploadAlgorithmCode(ctx context.Context, resourceType string, card string, taskType string, dataset string, algorithm string, code string) error UploadAlgorithmCode(ctx context.Context, resourceType string, card string, taskType string, dataset string, algorithm string, code string) error
GetComputeCards(ctx context.Context) ([]string, error) GetComputeCards(ctx context.Context) ([]string, error)
GetUserBalance(ctx context.Context) (float64, error) GetUserBalance(ctx context.Context) (float64, error)
GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*ImageInferUrl, error) GetInferUrl(ctx context.Context, option *option.InferOption) ([]*InferUrl, error)
} }
type ImageInferUrl struct { type InferUrl struct {
Url string Url string
Card string Card string
} }

View File

@ -378,8 +378,8 @@ func (m *ModelArtsLink) generateAlgorithmId(ctx context.Context, option *option.
return errors.New("failed to get AlgorithmId") return errors.New("failed to get AlgorithmId")
} }
func (m *ModelArtsLink) GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.ImageInferUrl, error) { func (m *ModelArtsLink) GetInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.InferUrl, error) {
var imageUrls []*collector.ImageInferUrl var imageUrls []*collector.InferUrl
urlReq := &modelartsclient.ImageReasoningUrlReq{ urlReq := &modelartsclient.ImageReasoningUrlReq{
ModelName: option.ModelName, ModelName: option.ModelName,
Type: option.ModelType, Type: option.ModelType,
@ -389,7 +389,7 @@ func (m *ModelArtsLink) GetImageInferUrl(ctx context.Context, option *option.Inf
if err != nil { if err != nil {
return nil, err return nil, err
} }
imageUrl := &collector.ImageInferUrl{ imageUrl := &collector.InferUrl{
Url: urlResp.Url, Url: urlResp.Url,
Card: "npu", Card: "npu",
} }

View File

@ -871,7 +871,7 @@ func setResourceIdByCard(option *option.AiOption, specs *octopus.GetResourceSpec
return errors.New("set ResourceId error") return errors.New("set ResourceId error")
} }
func (o *OctopusLink) GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.ImageInferUrl, error) { func (o *OctopusLink) GetInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.InferUrl, error) {
req := &octopus.GetNotebookListReq{ req := &octopus.GetNotebookListReq{
Platform: o.platform, Platform: o.platform,
PageIndex: o.pageIndex, PageIndex: o.pageIndex,
@ -882,12 +882,12 @@ func (o *OctopusLink) GetImageInferUrl(ctx context.Context, option *option.Infer
return nil, err return nil, err
} }
var imageUrls []*collector.ImageInferUrl var imageUrls []*collector.InferUrl
for _, notebook := range list.Payload.GetNotebooks() { for _, notebook := range list.Payload.GetNotebooks() {
if strings.Contains(notebook.AlgorithmName, option.ModelName) && notebook.Status == "running" { if strings.Contains(notebook.AlgorithmName, option.ModelName) && notebook.Status == "running" {
url := strings.Replace(notebook.Tasks[0].Url, FORWARD_SLASH, "", -1) url := strings.Replace(notebook.Tasks[0].Url, FORWARD_SLASH, "", -1)
names := strings.Split(notebook.AlgorithmName, UNDERSCORE) names := strings.Split(notebook.AlgorithmName, UNDERSCORE)
imageUrl := &collector.ImageInferUrl{ imageUrl := &collector.InferUrl{
Url: DOMAIN + url + FORWARD_SLASH + "image", Url: DOMAIN + url + FORWARD_SLASH + "image",
Card: names[2], Card: names[2],
} }

View File

@ -730,8 +730,8 @@ func (s *ShuguangAi) generateParams(option *option.AiOption) error {
return errors.New("failed to set params") return errors.New("failed to set params")
} }
func (s *ShuguangAi) GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.ImageInferUrl, error) { func (s *ShuguangAi) GetInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.InferUrl, error) {
var imageUrls []*collector.ImageInferUrl var imageUrls []*collector.InferUrl
urlReq := &hpcAC.GetInferUrlReq{ urlReq := &hpcAC.GetInferUrlReq{
ModelName: option.ModelName, ModelName: option.ModelName,
@ -743,7 +743,7 @@ func (s *ShuguangAi) GetImageInferUrl(ctx context.Context, option *option.InferO
if err != nil { if err != nil {
return nil, err return nil, err
} }
imageUrl := &collector.ImageInferUrl{ imageUrl := &collector.InferUrl{
Url: urlResp.Url, Url: urlResp.Url,
Card: "dcu", Card: "dcu",
} }

View File

@ -78,6 +78,7 @@ var (
} }
ModelTypeMap = map[string][]string{ ModelTypeMap = map[string][]string{
"image_recognition": {"imagenet_resnet50"}, "image_recognition": {"imagenet_resnet50"},
"text_to_text": {"chatGLM-6B"},
} }
AITYPE = map[string]string{ AITYPE = map[string]string{
"1": OCTOPUS, "1": OCTOPUS,

View File

@ -54,6 +54,7 @@ type (
TaskType string `db:"task_type"` TaskType string `db:"task_type"`
DeletedAt *time.Time `db:"deleted_at"` DeletedAt *time.Time `db:"deleted_at"`
Card string `db:"card"` Card string `db:"card"`
InferUrl string `db:"infer_url"`
} }
) )