added textinfer api
Former-commit-id: bfdce9025139d71bc3178b039756d579be1b450c
This commit is contained in:
parent
60012ab0fb
commit
e6b9d3d23b
|
@ -1,28 +1,25 @@
|
||||||
package inference
|
package inference
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/rest/httpx"
|
"github.com/zeromicro/go-zero/rest/httpx"
|
||||||
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/logic/inference"
|
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/logic/inference"
|
||||||
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc"
|
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc"
|
||||||
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types"
|
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types"
|
||||||
|
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/repository/result"
|
||||||
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TextToTextInferenceHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
|
func TextToTextInferenceHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
var req types.TextToTextInferenceReq
|
var req types.TextToTextInferenceReq
|
||||||
if err := httpx.Parse(r, &req); err != nil {
|
if err := httpx.Parse(r, &req); err != nil {
|
||||||
httpx.ErrorCtx(r.Context(), w, err)
|
result.ParamErrorResult(r, w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
l := inference.NewTextToTextInferenceLogic(r.Context(), svcCtx)
|
l := inference.NewTextToTextInferenceLogic(r.Context(), svcCtx)
|
||||||
resp, err := l.TextToTextInference(&req)
|
resp, err := l.TextToTextInference(&req)
|
||||||
if err != nil {
|
result.HttpResult(r, w, resp, err)
|
||||||
httpx.ErrorCtx(r.Context(), w, err)
|
|
||||||
} else {
|
|
||||||
httpx.OkJsonCtx(r.Context(), w, resp)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -133,14 +133,14 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
var cluster_ch = make(chan struct {
|
var cluster_ch = make(chan struct {
|
||||||
urls []*collector.ImageInferUrl
|
urls []*collector.InferUrl
|
||||||
clusterId string
|
clusterId string
|
||||||
clusterName string
|
clusterName string
|
||||||
imageNum int32
|
imageNum int32
|
||||||
}, len(clusters))
|
}, len(clusters))
|
||||||
|
|
||||||
var cs []struct {
|
var cs []struct {
|
||||||
urls []*collector.ImageInferUrl
|
urls []*collector.InferUrl
|
||||||
clusterId string
|
clusterId string
|
||||||
clusterName string
|
clusterName string
|
||||||
imageNum int32
|
imageNum int32
|
||||||
|
@ -182,7 +182,7 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
c := cluster
|
c := cluster
|
||||||
go func() {
|
go func() {
|
||||||
imageUrls, err := collectorMap[c.ClusterId].GetImageInferUrl(ctx, opt)
|
imageUrls, err := collectorMap[c.ClusterId].GetInferUrl(ctx, opt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
wg.Done()
|
wg.Done()
|
||||||
return
|
return
|
||||||
|
@ -190,7 +190,7 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s
|
||||||
clusterName, _ := svcCtx.Scheduler.AiStorages.GetClusterNameById(c.ClusterId)
|
clusterName, _ := svcCtx.Scheduler.AiStorages.GetClusterNameById(c.ClusterId)
|
||||||
|
|
||||||
s := struct {
|
s := struct {
|
||||||
urls []*collector.ImageInferUrl
|
urls []*collector.InferUrl
|
||||||
clusterId string
|
clusterId string
|
||||||
clusterName string
|
clusterName string
|
||||||
imageNum int32
|
imageNum int32
|
||||||
|
@ -373,7 +373,7 @@ func sendInferReq(images []struct {
|
||||||
imageResult *types.ImageResult
|
imageResult *types.ImageResult
|
||||||
file multipart.File
|
file multipart.File
|
||||||
}, cluster struct {
|
}, cluster struct {
|
||||||
urls []*collector.ImageInferUrl
|
urls []*collector.InferUrl
|
||||||
clusterId string
|
clusterId string
|
||||||
clusterName string
|
clusterName string
|
||||||
imageNum int32
|
imageNum int32
|
||||||
|
@ -384,7 +384,7 @@ func sendInferReq(images []struct {
|
||||||
imageResult *types.ImageResult
|
imageResult *types.ImageResult
|
||||||
file multipart.File
|
file multipart.File
|
||||||
}, c struct {
|
}, c struct {
|
||||||
urls []*collector.ImageInferUrl
|
urls []*collector.InferUrl
|
||||||
clusterId string
|
clusterId string
|
||||||
clusterName string
|
clusterName string
|
||||||
imageNum int32
|
imageNum int32
|
||||||
|
@ -494,7 +494,7 @@ type Res struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func contains(cs []struct {
|
func contains(cs []struct {
|
||||||
urls []*collector.ImageInferUrl
|
urls []*collector.InferUrl
|
||||||
clusterId string
|
clusterId string
|
||||||
clusterName string
|
clusterName string
|
||||||
imageNum int32
|
imageNum int32
|
||||||
|
|
|
@ -2,11 +2,18 @@ package inference
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
|
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/schedulers/option"
|
||||||
|
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/collector"
|
||||||
|
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/storeLink"
|
||||||
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc"
|
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc"
|
||||||
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types"
|
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types"
|
||||||
|
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TextToTextInferenceLogic struct {
|
type TextToTextInferenceLogic struct {
|
||||||
|
@ -24,7 +31,110 @@ func NewTextToTextInferenceLogic(ctx context.Context, svcCtx *svc.ServiceContext
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *TextToTextInferenceLogic) TextToTextInference(req *types.TextToTextInferenceReq) (resp *types.TextToTextInferenceResp, err error) {
|
func (l *TextToTextInferenceLogic) TextToTextInference(req *types.TextToTextInferenceReq) (resp *types.TextToTextInferenceResp, err error) {
|
||||||
// todo: add your logic here and delete this line
|
resp = &types.TextToTextInferenceResp{}
|
||||||
|
opt := &option.InferOption{
|
||||||
|
TaskName: req.TaskName,
|
||||||
|
TaskDesc: req.TaskDesc,
|
||||||
|
AdapterId: req.AdapterId,
|
||||||
|
AiClusterIds: req.AiClusterIds,
|
||||||
|
ModelName: req.ModelName,
|
||||||
|
ModelType: req.ModelType,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, ok := l.svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId]
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("AdapterId does not exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
//save task
|
||||||
|
var synergystatus int64
|
||||||
|
var strategyCode int64
|
||||||
|
adapterName, err := l.svcCtx.Scheduler.AiStorages.GetAdapterNameById(opt.AdapterId)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := l.svcCtx.Scheduler.AiStorages.SaveTask(opt.TaskName, strategyCode, synergystatus, "12")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
var cluster_ch = make(chan struct {
|
||||||
|
urls []*collector.InferUrl
|
||||||
|
clusterId string
|
||||||
|
clusterName string
|
||||||
|
}, len(opt.AiClusterIds))
|
||||||
|
|
||||||
|
var cs []struct {
|
||||||
|
urls []*collector.InferUrl
|
||||||
|
clusterId string
|
||||||
|
clusterName string
|
||||||
|
}
|
||||||
|
collectorMap := l.svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId]
|
||||||
|
|
||||||
|
//save taskai
|
||||||
|
for _, clusterId := range opt.AiClusterIds {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(cId string) {
|
||||||
|
urls, err := collectorMap[cId].GetInferUrl(l.ctx, opt)
|
||||||
|
if err != nil {
|
||||||
|
wg.Done()
|
||||||
return
|
return
|
||||||
|
}
|
||||||
|
clusterName, _ := l.svcCtx.Scheduler.AiStorages.GetClusterNameById(cId)
|
||||||
|
|
||||||
|
s := struct {
|
||||||
|
urls []*collector.InferUrl
|
||||||
|
clusterId string
|
||||||
|
clusterName string
|
||||||
|
}{
|
||||||
|
urls: urls,
|
||||||
|
clusterId: cId,
|
||||||
|
clusterName: clusterName,
|
||||||
|
}
|
||||||
|
|
||||||
|
cluster_ch <- s
|
||||||
|
wg.Done()
|
||||||
|
return
|
||||||
|
}(clusterId)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
close(cluster_ch)
|
||||||
|
|
||||||
|
for s := range cluster_ch {
|
||||||
|
cs = append(cs, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range cs {
|
||||||
|
clusterName, _ := l.svcCtx.Scheduler.AiStorages.GetClusterNameById(c.clusterId)
|
||||||
|
err := l.svcCtx.Scheduler.AiStorages.SaveAiTask(id, opt, adapterName, c.clusterId, clusterName, "", constants.Saved, "")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var aiTaskList []*models.TaskAi
|
||||||
|
tx := l.svcCtx.DbEngin.Raw("select * from task_ai where `task_id` = ? ", id).Scan(&aiTaskList)
|
||||||
|
if tx.Error != nil {
|
||||||
|
return nil, tx.Error
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, t := range aiTaskList {
|
||||||
|
if strconv.Itoa(int(t.ClusterId)) == cs[i].clusterId {
|
||||||
|
t.Status = constants.Completed
|
||||||
|
t.EndTime = time.Now().Format(time.RFC3339)
|
||||||
|
url := cs[i].urls[0].Url + storeLink.FORWARD_SLASH + "chat"
|
||||||
|
t.InferUrl = url
|
||||||
|
err := l.svcCtx.Scheduler.AiStorages.UpdateAiTask(t)
|
||||||
|
if err != nil {
|
||||||
|
logx.Errorf(tx.Error.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
l.svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "completed", "任务完成")
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,10 +15,10 @@ type AiCollector interface {
|
||||||
UploadAlgorithmCode(ctx context.Context, resourceType string, card string, taskType string, dataset string, algorithm string, code string) error
|
UploadAlgorithmCode(ctx context.Context, resourceType string, card string, taskType string, dataset string, algorithm string, code string) error
|
||||||
GetComputeCards(ctx context.Context) ([]string, error)
|
GetComputeCards(ctx context.Context) ([]string, error)
|
||||||
GetUserBalance(ctx context.Context) (float64, error)
|
GetUserBalance(ctx context.Context) (float64, error)
|
||||||
GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*ImageInferUrl, error)
|
GetInferUrl(ctx context.Context, option *option.InferOption) ([]*InferUrl, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// InferUrl describes one inference endpoint returned by an AiCollector's
// GetInferUrl method.
type InferUrl struct {
	// Url is the endpoint address; Card is the compute-card label the
	// collector attaches to it (the collectors in this change set values such
	// as "npu" and "dcu").
	Url, Card string
}
|
||||||
|
|
|
@ -378,8 +378,8 @@ func (m *ModelArtsLink) generateAlgorithmId(ctx context.Context, option *option.
|
||||||
return errors.New("failed to get AlgorithmId")
|
return errors.New("failed to get AlgorithmId")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *ModelArtsLink) GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.ImageInferUrl, error) {
|
func (m *ModelArtsLink) GetInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.InferUrl, error) {
|
||||||
var imageUrls []*collector.ImageInferUrl
|
var imageUrls []*collector.InferUrl
|
||||||
urlReq := &modelartsclient.ImageReasoningUrlReq{
|
urlReq := &modelartsclient.ImageReasoningUrlReq{
|
||||||
ModelName: option.ModelName,
|
ModelName: option.ModelName,
|
||||||
Type: option.ModelType,
|
Type: option.ModelType,
|
||||||
|
@ -389,7 +389,7 @@ func (m *ModelArtsLink) GetImageInferUrl(ctx context.Context, option *option.Inf
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
imageUrl := &collector.ImageInferUrl{
|
imageUrl := &collector.InferUrl{
|
||||||
Url: urlResp.Url,
|
Url: urlResp.Url,
|
||||||
Card: "npu",
|
Card: "npu",
|
||||||
}
|
}
|
||||||
|
|
|
@ -871,7 +871,7 @@ func setResourceIdByCard(option *option.AiOption, specs *octopus.GetResourceSpec
|
||||||
return errors.New("set ResourceId error")
|
return errors.New("set ResourceId error")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *OctopusLink) GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.ImageInferUrl, error) {
|
func (o *OctopusLink) GetInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.InferUrl, error) {
|
||||||
req := &octopus.GetNotebookListReq{
|
req := &octopus.GetNotebookListReq{
|
||||||
Platform: o.platform,
|
Platform: o.platform,
|
||||||
PageIndex: o.pageIndex,
|
PageIndex: o.pageIndex,
|
||||||
|
@ -882,12 +882,12 @@ func (o *OctopusLink) GetImageInferUrl(ctx context.Context, option *option.Infer
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var imageUrls []*collector.ImageInferUrl
|
var imageUrls []*collector.InferUrl
|
||||||
for _, notebook := range list.Payload.GetNotebooks() {
|
for _, notebook := range list.Payload.GetNotebooks() {
|
||||||
if strings.Contains(notebook.AlgorithmName, option.ModelName) && notebook.Status == "running" {
|
if strings.Contains(notebook.AlgorithmName, option.ModelName) && notebook.Status == "running" {
|
||||||
url := strings.Replace(notebook.Tasks[0].Url, FORWARD_SLASH, "", -1)
|
url := strings.Replace(notebook.Tasks[0].Url, FORWARD_SLASH, "", -1)
|
||||||
names := strings.Split(notebook.AlgorithmName, UNDERSCORE)
|
names := strings.Split(notebook.AlgorithmName, UNDERSCORE)
|
||||||
imageUrl := &collector.ImageInferUrl{
|
imageUrl := &collector.InferUrl{
|
||||||
Url: DOMAIN + url + FORWARD_SLASH + "image",
|
Url: DOMAIN + url + FORWARD_SLASH + "image",
|
||||||
Card: names[2],
|
Card: names[2],
|
||||||
}
|
}
|
||||||
|
|
|
@ -730,8 +730,8 @@ func (s *ShuguangAi) generateParams(option *option.AiOption) error {
|
||||||
return errors.New("failed to set params")
|
return errors.New("failed to set params")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ShuguangAi) GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.ImageInferUrl, error) {
|
func (s *ShuguangAi) GetInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.InferUrl, error) {
|
||||||
var imageUrls []*collector.ImageInferUrl
|
var imageUrls []*collector.InferUrl
|
||||||
|
|
||||||
urlReq := &hpcAC.GetInferUrlReq{
|
urlReq := &hpcAC.GetInferUrlReq{
|
||||||
ModelName: option.ModelName,
|
ModelName: option.ModelName,
|
||||||
|
@ -743,7 +743,7 @@ func (s *ShuguangAi) GetImageInferUrl(ctx context.Context, option *option.InferO
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
imageUrl := &collector.ImageInferUrl{
|
imageUrl := &collector.InferUrl{
|
||||||
Url: urlResp.Url,
|
Url: urlResp.Url,
|
||||||
Card: "dcu",
|
Card: "dcu",
|
||||||
}
|
}
|
||||||
|
|
|
@ -78,6 +78,7 @@ var (
|
||||||
}
|
}
|
||||||
ModelTypeMap = map[string][]string{
|
ModelTypeMap = map[string][]string{
|
||||||
"image_recognition": {"imagenet_resnet50"},
|
"image_recognition": {"imagenet_resnet50"},
|
||||||
|
"text_to_text": {"chatGLM-6B"},
|
||||||
}
|
}
|
||||||
AITYPE = map[string]string{
|
AITYPE = map[string]string{
|
||||||
"1": OCTOPUS,
|
"1": OCTOPUS,
|
||||||
|
|
|
@ -54,6 +54,7 @@ type (
|
||||||
TaskType string `db:"task_type"`
|
TaskType string `db:"task_type"`
|
||||||
DeletedAt *time.Time `db:"deleted_at"`
|
DeletedAt *time.Time `db:"deleted_at"`
|
||||||
Card string `db:"card"`
|
Card string `db:"card"`
|
||||||
|
InferUrl string `db:"infer_url"`
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue