Former-commit-id: a6b8cdc300e3331b600c07a0dac5c5e1d4663e5c
This commit is contained in:
zhangwei 2024-06-27 09:15:03 +08:00
commit 1cd3fbb1e2
27 changed files with 1145 additions and 530 deletions

View File

@ -1,4 +1,4 @@
FROM golang:1.22-alpine3.18 AS builder FROM golang:1.22.4-alpine3.20 AS builder
WORKDIR /app WORKDIR /app
@ -9,7 +9,7 @@ RUN go env -w GO111MODULE=on \
&& go env -w CGO_ENABLED=0 && go env -w CGO_ENABLED=0
RUN go build -o pcm-coordinator-api /app/api/pcm.go RUN go build -o pcm-coordinator-api /app/api/pcm.go
FROM alpine:3.18 FROM alpine:3.20
WORKDIR /app WORKDIR /app

View File

@ -175,7 +175,7 @@ type VmInfo struct {
//DeletedAt string `json:"deletedAt,omitempty"` //DeletedAt string `json:"deletedAt,omitempty"`
VmName string `json:"vmName,omitempty"` VmName string `json:"vmName,omitempty"`
Replicas int64 `json:"replicas,omitempty"` Replicas int64 `json:"replicas,omitempty"`
ServerId string `json:"serverId,omitempty"` //ServerId string `json:"serverId,omitempty"`
} }
type ResourceStats struct { type ResourceStats struct {

View File

@ -1819,3 +1819,11 @@ service AICore-api {
@handler createVisualizationJobHandler @handler createVisualizationJobHandler
post /CreateVisualizationJob (CreateVisualizationJobReq) returns (CreateVisualizationJobResp) post /CreateVisualizationJob (CreateVisualizationJobReq) returns (CreateVisualizationJobResp)
}*/ }*/
type (
ChatReq{
ApiUrl string `json:"apiUrl,optional"`
Method string `json:"method,optional"`
ReqData map[string]interface{} `json:"reqData"`
}
)

View File

@ -404,8 +404,8 @@ type (
TenantId string `json:"tenantId,omitempty" db:"tenant_id"` TenantId string `json:"tenantId,omitempty" db:"tenant_id"`
CreatedTime string `json:"createdTime,omitempty" db:"created_time" gorm:"autoCreateTime"` CreatedTime string `json:"createdTime,omitempty" db:"created_time" gorm:"autoCreateTime"`
UpdatedTime string `json:"updatedTime,omitempty" db:"updated_time"` UpdatedTime string `json:"updatedTime,omitempty" db:"updated_time"`
AdapterTypeDict int `json:"adapterTypeDict" db:"adapter_type_dict" gorm:"adapter_type_dict"` //适配器类型(对应字典表的值 AdapterTypeDict string `json:"adapterTypeDict" db:"adapter_type_dict" gorm:"adapter_type_dict"` //适配器类型(对应字典表的值
TaskTypeDict int `json:"taskTypeDict" db:"task_type_dict" gorm:"task_type_dict"` //任务类型(对应字典表的值 TaskTypeDict string `json:"taskTypeDict" db:"task_type_dict" gorm:"task_type_dict"` //任务类型(对应字典表的值
} }
) )
@ -1286,5 +1286,14 @@ type (
ClusterName string `json:"clusterName" db:"cluster_name"` ClusterName string `json:"clusterName" db:"cluster_name"`
Status string `json:"status" db:"status"` Status string `json:"status" db:"status"`
Remark string `json:"remark" db:"remark"` Remark string `json:"remark" db:"remark"`
InferUrl string `json:"inferUrl"`
}
)
type (
CommonResp {
Code int `json:"code,omitempty"`
Msg string `json:"msg,omitempty"`
Data interface{} `json:"data,omitempty"`
} }
) )

View File

@ -63,7 +63,17 @@ type (
clusterName string `json:"clusterName"` clusterName string `json:"clusterName"`
} }
/******************TextToText inference*************************/
TextToTextInferenceReq{
TaskName string `form:"taskName"`
TaskDesc string `form:"taskDesc"`
ModelName string `form:"modelName"`
ModelType string `form:"modelType"`
AdapterId string `form:"adapterId"`
AiClusterIds []string `form:"aiClusterIds"`
}
TextToTextInferenceResp{
}
) )

View File

@ -379,6 +379,12 @@ service pcm {
@handler createVisualizationJobHandler @handler createVisualizationJobHandler
post /ai/CreateVisualizationJob (CreateVisualizationJobReq) returns (CreateVisualizationJobResp) post /ai/CreateVisualizationJob (CreateVisualizationJobReq) returns (CreateVisualizationJobResp)
/******************Visualization Job Method start*************************/ /******************Visualization Job Method start*************************/
/***********chat***********/
@doc "文本识别"
@handler ProxyApiHandler
post /ai/chat (ChatReq) returns (CommonResp)
/******chat end***********/
} }
//screen接口 //screen接口
@ -920,6 +926,9 @@ service pcm {
group: inference group: inference
) )
service pcm { service pcm {
@handler TextToTextInferenceHandler
post /inference/text (TextToTextInferenceReq) returns (TextToTextInferenceResp)
@handler ImageInferenceHandler @handler ImageInferenceHandler
post /inference/images (ImageInferenceReq) returns (ImageInferenceResp) post /inference/images (ImageInferenceReq) returns (ImageInferenceResp)

View File

@ -58,7 +58,7 @@ func UpdateAiTaskStatus(svc *svc.ServiceContext, tasklist []*types.TaskModel) {
list := make([]*types.TaskModel, len(tasklist)) list := make([]*types.TaskModel, len(tasklist))
copy(list, tasklist) copy(list, tasklist)
for i := len(list) - 1; i >= 0; i-- { for i := len(list) - 1; i >= 0; i-- {
if list[i].AdapterTypeDict != 1 || list[i].Status == constants.Succeeded || list[i].Status == constants.Failed { if list[i].AdapterTypeDict != "1" || list[i].Status == constants.Succeeded || list[i].Status == constants.Failed {
list = append(list[:i], list[i+1:]...) list = append(list[:i], list[i+1:]...)
} }
} }
@ -155,7 +155,7 @@ func UpdateTaskStatus(svc *svc.ServiceContext, tasklist []*types.TaskModel) {
list := make([]*types.TaskModel, len(tasklist)) list := make([]*types.TaskModel, len(tasklist))
copy(list, tasklist) copy(list, tasklist)
for i := len(list) - 1; i >= 0; i-- { for i := len(list) - 1; i >= 0; i-- {
if list[i].AdapterTypeDict != 1 || list[i].Status == constants.Succeeded || list[i].Status == constants.Failed { if list[i].AdapterTypeDict != "1" || list[i].Status == constants.Succeeded || list[i].Status == constants.Failed {
list = append(list[:i], list[i+1:]...) list = append(list[:i], list[i+1:]...)
} }
} }
@ -174,7 +174,7 @@ func UpdateTaskStatus(svc *svc.ServiceContext, tasklist []*types.TaskModel) {
} }
// Update Infer Task Status // Update Infer Task Status
if task.TaskTypeDict == 11 { if task.TaskTypeDict == "11" || task.TaskTypeDict == "12" {
UpdateInferTaskStatus(svc, task) UpdateInferTaskStatus(svc, task)
return return
} }
@ -229,30 +229,14 @@ func UpdateTaskStatus(svc *svc.ServiceContext, tasklist []*types.TaskModel) {
return return
} }
var start time.Time start, _ := time.ParseInLocation(constants.Layout, aiTask[0].StartTime, time.Local)
var end time.Time end, _ := time.ParseInLocation(constants.Layout, aiTask[0].EndTime, time.Local)
// distinguish train or infer temporarily
if task.TaskTypeDict == 11 {
start, _ = time.ParseInLocation(time.RFC3339, aiTask[0].StartTime, time.Local)
end, _ = time.ParseInLocation(time.RFC3339, aiTask[0].EndTime, time.Local)
} else {
start, _ = time.ParseInLocation(constants.Layout, aiTask[0].StartTime, time.Local)
end, _ = time.ParseInLocation(constants.Layout, aiTask[0].EndTime, time.Local)
}
var status string var status string
var count int var count int
for _, a := range aiTask { for _, a := range aiTask {
var s time.Time s, _ := time.ParseInLocation(constants.Layout, a.StartTime, time.Local)
var e time.Time e, _ := time.ParseInLocation(constants.Layout, a.EndTime, time.Local)
// distinguish train or infer temporarily
if task.TaskTypeDict == 11 {
s, _ = time.ParseInLocation(time.RFC3339, a.StartTime, time.Local)
e, _ = time.ParseInLocation(time.RFC3339, a.EndTime, time.Local)
} else {
s, _ = time.ParseInLocation(constants.Layout, a.StartTime, time.Local)
e, _ = time.ParseInLocation(constants.Layout, a.EndTime, time.Local)
}
if s.Before(start) { if s.Before(start) {
start = s start = s
@ -289,17 +273,10 @@ func UpdateTaskStatus(svc *svc.ServiceContext, tasklist []*types.TaskModel) {
if status != "" { if status != "" {
task.Status = status task.Status = status
// distinguish train or infer temporarily
if task.TaskTypeDict == 11 {
task.StartTime = start.Format(time.RFC3339)
task.EndTime = end.Format(time.RFC3339)
} else {
task.StartTime = start.Format(constants.Layout) task.StartTime = start.Format(constants.Layout)
task.EndTime = end.Format(constants.Layout) task.EndTime = end.Format(constants.Layout)
} }
}
task.UpdatedTime = time.Now().Format(constants.Layout) task.UpdatedTime = time.Now().Format(constants.Layout)
tx = svc.DbEngin.Table("task").Model(task).Updates(task) tx = svc.DbEngin.Table("task").Model(task).Updates(task)
if tx.Error != nil { if tx.Error != nil {
@ -561,6 +538,10 @@ func UpdateInferTaskStatus(svc *svc.ServiceContext, task *types.TaskModel) {
// return // return
//} //}
if aiTask[0].StartTime == "" {
return
}
start, _ := time.ParseInLocation(time.RFC3339, aiTask[0].StartTime, time.Local) start, _ := time.ParseInLocation(time.RFC3339, aiTask[0].StartTime, time.Local)
end, _ := time.ParseInLocation(time.RFC3339, aiTask[0].EndTime, time.Local) end, _ := time.ParseInLocation(time.RFC3339, aiTask[0].EndTime, time.Local)
var status string var status string

View File

@ -0,0 +1,24 @@
package ai
import (
"github.com/zeromicro/go-zero/rest/httpx"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/logic/ai"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/repository/result"
"net/http"
)
// ProxyApiHandler builds the HTTP handler for the AI chat proxy endpoint.
// It parses the request body into a ChatReq and delegates to ProxyApiLogic.
func ProxyApiHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		var chatReq types.ChatReq
		err := httpx.Parse(r, &chatReq)
		if err != nil {
			result.ParamErrorResult(r, w, err)
			return
		}
		logic := ai.NewProxyApiLogic(r.Context(), svcCtx)
		resp, err := logic.ProxyApi(&chatReq, w)
		result.HttpResult(r, w, resp, err)
	}
}

View File

@ -0,0 +1,25 @@
package inference
import (
"github.com/zeromicro/go-zero/rest/httpx"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/logic/inference"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/repository/result"
"net/http"
)
// TextToTextInferenceHandler builds the HTTP handler for POST /inference/text.
// It parses the form-encoded request and delegates to TextToTextInferenceLogic.
func TextToTextInferenceHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		var inferReq types.TextToTextInferenceReq
		err := httpx.Parse(r, &inferReq)
		if err != nil {
			result.ParamErrorResult(r, w, err)
			return
		}
		logic := inference.NewTextToTextInferenceLogic(r.Context(), svcCtx)
		resp, err := logic.TextToTextInference(&inferReq)
		result.HttpResult(r, w, resp, err)
	}
}

View File

@ -452,6 +452,11 @@ func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) {
Path: "/ai/CreateVisualizationJob", Path: "/ai/CreateVisualizationJob",
Handler: ai.CreateVisualizationJobHandler(serverCtx), Handler: ai.CreateVisualizationJobHandler(serverCtx),
}, },
{
Method: http.MethodPost,
Path: "/ai/chat",
Handler: ai.ProxyApiHandler(serverCtx),
},
}, },
rest.WithPrefix("/pcm/v1"), rest.WithPrefix("/pcm/v1"),
) )
@ -1153,6 +1158,11 @@ func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) {
server.AddRoutes( server.AddRoutes(
[]rest.Route{ []rest.Route{
{
Method: http.MethodPost,
Path: "/inference/text",
Handler: inference.TextToTextInferenceHandler(serverCtx),
},
{ {
Method: http.MethodPost, Method: http.MethodPost,
Path: "/inference/images", Path: "/inference/images",

View File

@ -0,0 +1,89 @@
package ai
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"fmt"
tool "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils/hws"
"net/http"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types"
"github.com/zeromicro/go-zero/core/logx"
)
// ProxyApiLogic forwards chat requests to a third-party inference API on
// behalf of the /ai/chat endpoint.
type ProxyApiLogic struct {
logx.Logger
ctx context.Context
svcCtx *svc.ServiceContext
}
// NewProxyApiLogic constructs a ProxyApiLogic bound to the given request
// context and service context.
func NewProxyApiLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ProxyApiLogic {
	logic := &ProxyApiLogic{
		Logger: logx.WithContext(ctx),
		ctx:    ctx,
		svcCtx: svcCtx,
	}
	return logic
}
// ChatResult is the payload placed in CommonResp.Data and returned to the caller.
type ChatResult struct {
Results string `json:"results"`
}
// ResponseData mirrors the third-party API's JSON response body.
type ResponseData struct {
Results string `json:"results"`
}
// ProxyApi forwards req.ReqData as a JSON POST body to the third-party
// inference API at req.ApiUrl, signing the request with HWS AK/SK, and wraps
// the upstream response in a CommonResp whose Code is the upstream HTTP
// status.
//
// The w parameter is unused; it is kept for compatibility with the generated
// handler signature.
func (l *ProxyApiLogic) ProxyApi(req *types.ChatReq, w http.ResponseWriter) (resp *types.CommonResp, err error) {
	jsonBytes, err := json.Marshal(&req.ReqData)
	if err != nil {
		// Original ignored this error and could POST a nil body.
		return nil, err
	}
	// 调用第三方接口的 POST 方法
	thirdReq, err := http.NewRequest("POST", req.ApiUrl, bytes.NewBuffer(jsonBytes))
	if err != nil {
		return
	}
	// SECURITY: hard-coded AK/SK credentials should be moved to configuration
	// or secret storage instead of living in source code.
	signer := &hws.Signer{
		Key:    "UNEHPHO4Z7YSNPKRXFE4",
		Secret: "JWXCE9qcYbc7RjpSRIWt4WgG3ZKF6Q4lPzkJReX9",
	}
	if err := signer.Sign(thirdReq); err != nil {
		return nil, err
	}
	// 设置client信任所有证书
	// SECURITY: InsecureSkipVerify disables TLS certificate validation and
	// allows man-in-the-middle attacks; remove once the upstream certificate
	// chain is trusted.
	tr := &http.Transport{
		TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
	}
	client := &http.Client{
		Transport: tr,
	}
	// Copy the signer-produced auth headers into the names the gateway expects.
	thirdReq.Header.Set("X-Project-Id", "d18190e28e3f45a281ef0b0696ec9d52")
	thirdReq.Header.Set("x-stage", "RELEASE")
	thirdReq.Header.Set("Authorization", thirdReq.Header.Get(hws.HeaderXAuthorization))
	thirdReq.Header.Set("X-Sdk-Date", thirdReq.Header.Get(hws.HeaderXDateTime))
	thirdReq.Header.Set("Content-Type", "application/json")
	thirdResp, err := client.Do(thirdReq)
	if err != nil {
		// Original dereferenced thirdResp without this check (nil-pointer
		// panic on any transport failure).
		return nil, err
	}
	// Close the RESPONSE body; the original deferred thirdReq.Body.Close()
	// and leaked the response connection.
	defer thirdResp.Body.Close()
	var responseData ResponseData
	if err := json.NewDecoder(thirdResp.Body).Decode(&responseData); err != nil {
		// Original printed the error and returned "success" with empty data.
		return nil, fmt.Errorf("decoding upstream response: %w", err)
	}
	chatResult := &ChatResult{}
	tool.Convert(responseData, &chatResult)
	return &types.CommonResp{
		Code: thirdResp.StatusCode,
		Msg:  "success",
		Data: chatResult,
	}, nil
}

View File

@ -57,9 +57,9 @@ func (l *PageListTaskLogic) PageListTask(req *types.PageTaskReq) (resp *types.Pa
} }
// 更新智算任务状态 // 更新智算任务状态
chs := [2]chan struct{}{make(chan struct{}), make(chan struct{})} //chs := [2]chan struct{}{make(chan struct{}), make(chan struct{})}
go l.updateTaskStatus(list, chs[0]) //go l.updateTaskStatus(list, chs[0])
go l.updateAiTaskStatus(list, chs[1]) //go l.updateAiTaskStatus(list, chs[1])
for _, model := range list { for _, model := range list {
if model.StartTime != "" && model.EndTime == "" { if model.StartTime != "" && model.EndTime == "" {
@ -77,12 +77,12 @@ func (l *PageListTaskLogic) PageListTask(req *types.PageTaskReq) (resp *types.Pa
resp.PageNum = req.PageNum resp.PageNum = req.PageNum
resp.Total = total resp.Total = total
for _, ch := range chs { //for _, ch := range chs {
select { // select {
case <-ch: // case <-ch:
case <-time.After(1 * time.Second): // case <-time.After(1 * time.Second):
} // }
} //}
return return
} }
@ -90,7 +90,7 @@ func (l *PageListTaskLogic) updateTaskStatus(tasklist []*types.TaskModel, ch cha
list := make([]*types.TaskModel, len(tasklist)) list := make([]*types.TaskModel, len(tasklist))
copy(list, tasklist) copy(list, tasklist)
for i := len(list) - 1; i >= 0; i-- { for i := len(list) - 1; i >= 0; i-- {
if list[i].AdapterTypeDict != 1 || list[i].Status == constants.Succeeded || list[i].Status == constants.Failed { if list[i].AdapterTypeDict != "1" || list[i].Status == constants.Succeeded || list[i].Status == constants.Failed {
list = append(list[:i], list[i+1:]...) list = append(list[:i], list[i+1:]...)
} }
} }
@ -228,7 +228,7 @@ func (l *PageListTaskLogic) updateAiTaskStatus(tasklist []*types.TaskModel, ch c
list := make([]*types.TaskModel, len(tasklist)) list := make([]*types.TaskModel, len(tasklist))
copy(list, tasklist) copy(list, tasklist)
for i := len(list) - 1; i >= 0; i-- { for i := len(list) - 1; i >= 0; i-- {
if list[i].AdapterTypeDict != 1 || list[i].Status == constants.Succeeded || list[i].Status == constants.Failed { if list[i].AdapterTypeDict != "1" || list[i].Status == constants.Succeeded || list[i].Status == constants.Failed {
list = append(list[:i], list[i+1:]...) list = append(list[:i], list[i+1:]...)
} }
} }

View File

@ -1,30 +1,15 @@
package inference package inference
import ( import (
"bytes"
"context" "context"
"crypto/tls"
"errors" "errors"
"fmt"
"github.com/JCCE-nudt/apigw-go-sdk/core"
"github.com/go-resty/resty/v2"
"github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/logx"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/schedulers/option" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/schedulers/option"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/collector" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/inference"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/strategy" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/strategy"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
"io"
"k8s.io/apimachinery/pkg/util/json"
"log"
"math/rand"
"mime/multipart"
"net/http" "net/http"
"strconv"
"sync"
"time"
) )
type ImageInferenceLogic struct { type ImageInferenceLogic struct {
@ -58,10 +43,7 @@ func (l *ImageInferenceLogic) ImageInfer(r *http.Request, req *types.ImageInfere
StaticWeightMap: req.StaticWeightMap, StaticWeightMap: req.StaticWeightMap,
} }
var ts []struct { var ts []*inference.ImageFile
imageResult *types.ImageResult
file multipart.File
}
uploadedFiles := r.MultipartForm.File uploadedFiles := r.MultipartForm.File
@ -81,14 +63,11 @@ func (l *ImageInferenceLogic) ImageInfer(r *http.Request, req *types.ImageInfere
defer file.Close() defer file.Close()
var ir types.ImageResult var ir types.ImageResult
ir.ImageName = header.Filename ir.ImageName = header.Filename
t := struct { t := inference.ImageFile{
imageResult *types.ImageResult ImageResult: &ir,
file multipart.File File: file,
}{
imageResult: &ir,
file: file,
} }
ts = append(ts, t) ts = append(ts, &t)
} }
_, ok := l.svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId] _, ok := l.svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId]
@ -111,464 +90,32 @@ func (l *ImageInferenceLogic) ImageInfer(r *http.Request, req *types.ImageInfere
return nil, err return nil, err
} }
results, err := infer(opt, clusters, ts, l.svcCtx, l.ctx)
if err != nil {
return nil, err
}
resp.InferResults = results
return resp, nil
}
func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []struct {
imageResult *types.ImageResult
file multipart.File
}, svcCtx *svc.ServiceContext, ctx context.Context) ([]*types.ImageResult, error) {
if clusters == nil || len(clusters) == 0 { if clusters == nil || len(clusters) == 0 {
return nil, errors.New("clusters is nil") return nil, errors.New("clusters is nil")
} }
for i := len(clusters) - 1; i >= 0; i-- {
if clusters[i].Replicas == 0 {
clusters = append(clusters[:i], clusters[i+1:]...)
}
}
var wg sync.WaitGroup
var cluster_ch = make(chan struct {
urls []*collector.ImageInferUrl
clusterId string
clusterName string
imageNum int32
}, len(clusters))
var cs []struct {
urls []*collector.ImageInferUrl
clusterId string
clusterName string
imageNum int32
}
collectorMap := svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId]
//save task //save task
var synergystatus int64 var synergystatus int64
if len(clusters) > 1 { if len(clusters) > 1 {
synergystatus = 1 synergystatus = 1
} }
strategyCode, err := svcCtx.Scheduler.AiStorages.GetStrategyCode(opt.Strategy) strategyCode, err := l.svcCtx.Scheduler.AiStorages.GetStrategyCode(opt.Strategy)
if err != nil { if err != nil {
return nil, err return nil, err
} }
adapterName, err := svcCtx.Scheduler.AiStorages.GetAdapterNameById(opt.AdapterId) adapterName, err := l.svcCtx.Scheduler.AiStorages.GetAdapterNameById(opt.AdapterId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
id, err := svcCtx.Scheduler.AiStorages.SaveTask(opt.TaskName, strategyCode, synergystatus, "11") id, err := l.svcCtx.Scheduler.AiStorages.SaveTask(opt.TaskName, strategyCode, synergystatus, "11")
if err != nil { if err != nil {
return nil, err return nil, err
} }
svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "create", "任务创建中") l.svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "create", "任务创建中")
//save taskai go l.svcCtx.Scheduler.AiService.ImageInfer(opt, id, adapterName, clusters, ts, l.ctx)
for _, c := range clusters {
clusterName, _ := svcCtx.Scheduler.AiStorages.GetClusterNameById(c.ClusterId)
opt.Replica = c.Replicas
err := svcCtx.Scheduler.AiStorages.SaveAiTask(id, opt, adapterName, c.ClusterId, clusterName, "", constants.Saved, "")
if err != nil {
return nil, err
}
}
for _, cluster := range clusters { return resp, nil
wg.Add(1)
c := cluster
go func() {
imageUrls, err := collectorMap[c.ClusterId].GetImageInferUrl(ctx, opt)
if err != nil {
wg.Done()
return
}
clusterName, _ := svcCtx.Scheduler.AiStorages.GetClusterNameById(c.ClusterId)
s := struct {
urls []*collector.ImageInferUrl
clusterId string
clusterName string
imageNum int32
}{
urls: imageUrls,
clusterId: c.ClusterId,
clusterName: clusterName,
imageNum: c.Replicas,
}
cluster_ch <- s
wg.Done()
return
}()
}
wg.Wait()
close(cluster_ch)
for s := range cluster_ch {
cs = append(cs, s)
}
var aiTaskList []*models.TaskAi
tx := svcCtx.DbEngin.Raw("select * from task_ai where `task_id` = ? ", id).Scan(&aiTaskList)
if tx.Error != nil {
return nil, tx.Error
}
//no cluster available
if len(cs) == 0 {
for _, t := range aiTaskList {
t.Status = constants.Failed
t.EndTime = time.Now().Format(time.RFC3339)
err := svcCtx.Scheduler.AiStorages.UpdateAiTask(t)
if err != nil {
logx.Errorf(tx.Error.Error())
}
}
svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败")
return nil, errors.New("image infer task failed")
}
//change cluster status
if len(clusters) != len(cs) {
var acs []*strategy.AssignedCluster
var rcs []*strategy.AssignedCluster
for _, cluster := range clusters {
if contains(cs, cluster.ClusterId) {
var ac *strategy.AssignedCluster
ac = cluster
rcs = append(rcs, ac)
} else {
var ac *strategy.AssignedCluster
ac = cluster
acs = append(acs, ac)
}
}
// update failed cluster status
for _, ac := range acs {
for _, t := range aiTaskList {
if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) {
t.Status = constants.Failed
t.EndTime = time.Now().Format(time.RFC3339)
err := svcCtx.Scheduler.AiStorages.UpdateAiTask(t)
if err != nil {
logx.Errorf(tx.Error.Error())
}
}
}
}
// update running cluster status
for _, ac := range rcs {
for _, t := range aiTaskList {
if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) {
t.Status = constants.Running
err := svcCtx.Scheduler.AiStorages.UpdateAiTask(t)
if err != nil {
logx.Errorf(tx.Error.Error())
}
}
}
}
svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败")
} else {
for _, t := range aiTaskList {
t.Status = constants.Running
err := svcCtx.Scheduler.AiStorages.UpdateAiTask(t)
if err != nil {
logx.Errorf(tx.Error.Error())
}
}
svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "running", "任务运行中")
}
var result_ch = make(chan *types.ImageResult, len(ts))
var results []*types.ImageResult
limit := make(chan bool, 7)
var imageNumIdx int32 = 0
var imageNumIdxEnd int32 = 0
for _, c := range cs {
new_images := make([]struct {
imageResult *types.ImageResult
file multipart.File
}, len(ts))
copy(new_images, ts)
imageNumIdxEnd = imageNumIdxEnd + c.imageNum
new_images = new_images[imageNumIdx:imageNumIdxEnd]
imageNumIdx = imageNumIdx + c.imageNum
wg.Add(len(new_images))
go sendInferReq(new_images, c, &wg, result_ch, limit)
}
wg.Wait()
close(result_ch)
for s := range result_ch {
results = append(results, s)
}
//sort.Slice(results, func(p, q int) bool {
// return results[p].ClusterName < results[q].ClusterName
//})
//save ai sub tasks
for _, r := range results {
for _, task := range aiTaskList {
if r.ClusterId == strconv.Itoa(int(task.ClusterId)) {
taskAiSub := models.TaskAiSub{
TaskId: id,
TaskName: task.Name,
TaskAiId: task.TaskId,
TaskAiName: task.Name,
ImageName: r.ImageName,
Result: r.ImageResult,
Card: r.Card,
ClusterId: task.ClusterId,
ClusterName: r.ClusterName,
}
tx := svcCtx.DbEngin.Table("task_ai_sub").Create(&taskAiSub)
if tx.Error != nil {
logx.Errorf(err.Error())
}
}
}
}
// update succeeded cluster status
var successStatusCount int
for _, c := range cs {
for _, t := range aiTaskList {
if c.clusterId == strconv.Itoa(int(t.ClusterId)) {
t.Status = constants.Completed
t.EndTime = time.Now().Format(time.RFC3339)
err := svcCtx.Scheduler.AiStorages.UpdateAiTask(t)
if err != nil {
logx.Errorf(tx.Error.Error())
}
successStatusCount++
} else {
continue
}
}
}
if len(cs) == successStatusCount {
svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "completed", "任务完成")
} else {
svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败")
}
return results, nil
}
func sendInferReq(images []struct {
imageResult *types.ImageResult
file multipart.File
}, cluster struct {
urls []*collector.ImageInferUrl
clusterId string
clusterName string
imageNum int32
}, wg *sync.WaitGroup, ch chan<- *types.ImageResult, limit chan bool) {
for _, image := range images {
limit <- true
go func(t struct {
imageResult *types.ImageResult
file multipart.File
}, c struct {
urls []*collector.ImageInferUrl
clusterId string
clusterName string
imageNum int32
}) {
if len(c.urls) == 1 {
r, err := getInferResult(c.urls[0].Url, t.file, t.imageResult.ImageName, c.clusterName)
if err != nil {
t.imageResult.ImageResult = err.Error()
t.imageResult.ClusterId = c.clusterId
t.imageResult.ClusterName = c.clusterName
t.imageResult.Card = c.urls[0].Card
ch <- t.imageResult
wg.Done()
<-limit
return
}
t.imageResult.ImageResult = r
t.imageResult.ClusterId = c.clusterId
t.imageResult.ClusterName = c.clusterName
t.imageResult.Card = c.urls[0].Card
ch <- t.imageResult
wg.Done()
<-limit
return
} else {
idx := rand.Intn(len(c.urls))
r, err := getInferResult(c.urls[idx].Url, t.file, t.imageResult.ImageName, c.clusterName)
if err != nil {
t.imageResult.ImageResult = err.Error()
t.imageResult.ClusterId = c.clusterId
t.imageResult.ClusterName = c.clusterName
t.imageResult.Card = c.urls[idx].Card
ch <- t.imageResult
wg.Done()
<-limit
return
}
t.imageResult.ImageResult = r
t.imageResult.ClusterId = c.clusterId
t.imageResult.ClusterName = c.clusterName
t.imageResult.Card = c.urls[idx].Card
ch <- t.imageResult
wg.Done()
<-limit
return
}
}(image, cluster)
<-limit
}
}
func getInferResult(url string, file multipart.File, fileName string, clusterName string) (string, error) {
if clusterName == "鹏城云脑II-modelarts" {
r, err := getInferResultModelarts(url, file, fileName)
if err != nil {
return "", err
}
return r, nil
}
var res Res
req := GetRestyRequest(20)
_, err := req.
SetFileReader("file", fileName, file).
SetResult(&res).
Post(url)
if err != nil {
return "", err
}
return res.Result, nil
}
func getInferResultModelarts(url string, file multipart.File, fileName string) (string, error) {
var res Res
body, err := SendRequest("POST", url, file, fileName)
if err != nil {
return "", err
}
errjson := json.Unmarshal([]byte(body), &res)
if errjson != nil {
log.Fatalf("Error parsing JSON: %s", errjson)
}
return res.Result, nil
}
// SignClient AK/SK签名认证
func SignClient(r *http.Request, writer *multipart.Writer) (*http.Client, error) {
r.Header.Add("content-type", "application/json;charset=UTF-8")
r.Header.Add("X-Project-Id", "d18190e28e3f45a281ef0b0696ec9d52")
r.Header.Add("x-stage", "RELEASE")
r.Header.Add("x-sdk-content-sha256", "UNSIGNED-PAYLOAD")
r.Header.Set("Content-Type", writer.FormDataContentType())
s := core.Signer{
Key: "UNEHPHO4Z7YSNPKRXFE4",
Secret: "JWXCE9qcYbc7RjpSRIWt4WgG3ZKF6Q4lPzkJReX9",
}
err := s.Sign(r)
if err != nil {
return nil, err
}
//设置client信任所有证书
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
client := &http.Client{
Transport: tr,
}
return client, nil
}
func SendRequest(method, url string, file multipart.File, fileName string) (string, error) {
/*body := &bytes.Buffer{}
writer := multipart.NewWriter(body)*/
// 创建一个新的缓冲区以写入multipart表单
var body bytes.Buffer
// 创建一个新的multipart writer
writer := multipart.NewWriter(&body)
// 创建一个用于写入文件的表单字段
part, err := writer.CreateFormFile("file", fileName) // "file"是表单的字段名,第二个参数是文件名
if err != nil {
fmt.Println("Error creating form file:", err)
}
// 将文件的内容拷贝到multipart writer中
_, err = io.Copy(part, file)
if err != nil {
fmt.Println("Error copying file data:", err)
}
err = writer.Close()
if err != nil {
fmt.Println("Error closing multipart writer:", err)
}
request, err := http.NewRequest(method, url, &body)
if err != nil {
fmt.Println("Error creating new request:", err)
//return nil, err
}
signedR, err := SignClient(request, writer)
if err != nil {
fmt.Println("Error signing request:", err)
//return nil, err
}
res, err := signedR.Do(request)
if err != nil {
fmt.Println("Error sending request:", err)
return "", err
}
//defer res.Body.Close()
Resbody, err := io.ReadAll(res.Body)
if err != nil {
fmt.Println("Error reading response body:", err)
//return nil, err
}
return string(Resbody), nil
}
func GetRestyRequest(timeoutSeconds int64) *resty.Request {
client := resty.New().SetTimeout(time.Duration(timeoutSeconds) * time.Second)
request := client.R()
return request
}
type Res struct {
Result string `json:"result"`
}
func contains(cs []struct {
urls []*collector.ImageInferUrl
clusterId string
clusterName string
imageNum int32
}, e string) bool {
for _, c := range cs {
if c.clusterId == e {
return true
}
}
return false
} }

View File

@ -0,0 +1,150 @@
package inference
import (
"context"
"errors"
"github.com/zeromicro/go-zero/core/logx"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/schedulers/option"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/collector"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/storeLink"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/svc"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
"strconv"
"sync"
"time"
)
// TextToTextInferenceLogic creates text-to-text (chat) inference tasks for
// the /inference/text endpoint.
type TextToTextInferenceLogic struct {
logx.Logger
ctx context.Context
svcCtx *svc.ServiceContext
}
// NewTextToTextInferenceLogic constructs a TextToTextInferenceLogic bound to
// the given request context and service context.
func NewTextToTextInferenceLogic(ctx context.Context, svcCtx *svc.ServiceContext) *TextToTextInferenceLogic {
	logic := &TextToTextInferenceLogic{
		Logger: logx.WithContext(ctx),
		ctx:    ctx,
		svcCtx: svcCtx,
	}
	return logic
}
// TextToTextInference creates a text-to-text (chat) inference task: it
// resolves inference URLs from every requested cluster in parallel, persists
// the parent task and one sub-task per cluster, and records the chat URL on
// each successfully resolved sub-task.
//
// The returned TextToTextInferenceResp is currently an empty placeholder;
// progress is tracked through the task tables and the notice feed.
func (l *TextToTextInferenceLogic) TextToTextInference(req *types.TextToTextInferenceReq) (resp *types.TextToTextInferenceResp, err error) {
	resp = &types.TextToTextInferenceResp{}
	opt := &option.InferOption{
		TaskName:     req.TaskName,
		TaskDesc:     req.TaskDesc,
		AdapterId:    req.AdapterId,
		AiClusterIds: req.AiClusterIds,
		ModelName:    req.ModelName,
		ModelType:    req.ModelType,
	}
	if len(opt.AiClusterIds) == 0 {
		// Guard: the failure path below reads AiClusterIds[0].
		return nil, errors.New("aiClusterIds is empty")
	}
	collectorMap, ok := l.svcCtx.Scheduler.AiService.AiCollectorAdapterMap[opt.AdapterId]
	if !ok {
		return nil, errors.New("AdapterId does not exist")
	}
	// Save the parent task. strategyCode/synergystatus stay zero: text
	// inference currently has no scheduling strategy or synergy mode.
	var synergystatus int64
	var strategyCode int64
	adapterName, err := l.svcCtx.Scheduler.AiStorages.GetAdapterNameById(opt.AdapterId)
	if err != nil {
		return nil, err
	}
	id, err := l.svcCtx.Scheduler.AiStorages.SaveTask(opt.TaskName, strategyCode, synergystatus, "12")
	if err != nil {
		return nil, err
	}

	// clusterInfer collects the inference URLs resolved for one cluster.
	type clusterInfer struct {
		urls        []*collector.InferUrl
		clusterId   string
		clusterName string
	}
	var wg sync.WaitGroup
	clusterCh := make(chan clusterInfer, len(opt.AiClusterIds))
	for _, clusterId := range opt.AiClusterIds {
		wg.Add(1)
		go func(cId string) {
			defer wg.Done()
			urls, err := collectorMap[cId].GetInferUrl(l.ctx, opt)
			if err != nil {
				// Cluster failed to provide a URL; it is omitted from the
				// result set and its sub-task stays unsaved/failed.
				return
			}
			clusterName, _ := l.svcCtx.Scheduler.AiStorages.GetClusterNameById(cId)
			clusterCh <- clusterInfer{urls: urls, clusterId: cId, clusterName: clusterName}
		}(clusterId)
	}
	wg.Wait()
	close(clusterCh)
	var cs []clusterInfer
	for s := range clusterCh {
		cs = append(cs, s)
	}

	// No cluster produced a URL: record one failed sub-task and stop here.
	// (The original fell through and indexed the empty cs slice, which
	// panics with index out of range.)
	if len(cs) == 0 {
		clusterId := opt.AiClusterIds[0]
		clusterName, _ := l.svcCtx.Scheduler.AiStorages.GetClusterNameById(clusterId)
		if err := l.svcCtx.Scheduler.AiStorages.SaveAiTask(id, opt, adapterName, clusterId, clusterName, "", constants.Failed, ""); err != nil {
			return nil, err
		}
		l.svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败")
		return nil, errors.New("text to text inference task failed: no cluster available")
	}

	// Persist one sub-task per successfully resolved cluster.
	for _, c := range cs {
		clusterName, _ := l.svcCtx.Scheduler.AiStorages.GetClusterNameById(c.clusterId)
		if err := l.svcCtx.Scheduler.AiStorages.SaveAiTask(id, opt, adapterName, c.clusterId, clusterName, "", constants.Saved, ""); err != nil {
			return nil, err
		}
	}
	var aiTaskList []*models.TaskAi
	tx := l.svcCtx.DbEngin.Raw("select * from task_ai where `task_id` = ? ", id).Scan(&aiTaskList)
	if tx.Error != nil {
		return nil, tx.Error
	}
	// Match sub-tasks to resolved clusters by cluster id rather than by slice
	// index: the channel drain order above is nondeterministic, so positional
	// matching (the original behavior) could pair a task with the wrong
	// cluster's URL.
	infoById := make(map[string]clusterInfer, len(cs))
	for _, c := range cs {
		infoById[c.clusterId] = c
	}
	for _, t := range aiTaskList {
		c, ok := infoById[strconv.Itoa(int(t.ClusterId))]
		if !ok || len(c.urls) == 0 {
			continue
		}
		t.Status = constants.Completed
		t.EndTime = time.Now().Format(time.RFC3339)
		t.InferUrl = c.urls[0].Url + storeLink.FORWARD_SLASH + "chat"
		if err := l.svcCtx.Scheduler.AiStorages.UpdateAiTask(t); err != nil {
			// Original logged tx.Error here, which is nil at this point.
			logx.Errorf(err.Error())
		}
	}
	l.svcCtx.Scheduler.AiStorages.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "completed", "任务完成")
	return resp, nil
}

View File

@ -94,6 +94,15 @@ func (s *AiStorage) GetAiTasksByAdapterId(adapterId string) ([]*models.TaskAi, e
return resp, nil return resp, nil
} }
// GetAiTaskListById returns every task_ai row that belongs to the given
// coordinator task id.
func (s *AiStorage) GetAiTaskListById(id int64) ([]*models.TaskAi, error) {
	var tasks []*models.TaskAi
	result := s.DbEngin.Raw("select * from task_ai where `task_id` = ? ", id).Scan(&tasks)
	if result.Error != nil {
		return nil, result.Error
	}
	return tasks, nil
}
func (s *AiStorage) SaveTask(name string, strategyCode int64, synergyStatus int64, aiType string) (int64, error) { func (s *AiStorage) SaveTask(name string, strategyCode int64, synergyStatus int64, aiType string) (int64, error) {
startTime := time.Now() startTime := time.Now()
// 构建主任务结构体 // 构建主任务结构体
@ -165,6 +174,14 @@ func (s *AiStorage) SaveAiTask(taskId int64, opt option.Option, adapterName stri
return nil return nil
} }
// SaveAiTaskImageSubTask inserts one per-image inference result row into the
// task_ai_sub table.
func (s *AiStorage) SaveAiTaskImageSubTask(ta *models.TaskAiSub) error {
	if result := s.DbEngin.Table("task_ai_sub").Create(ta); result.Error != nil {
		return result.Error
	}
	return nil
}
func (s *AiStorage) SaveClusterTaskQueue(adapterId string, clusterId string, queueNum int64) error { func (s *AiStorage) SaveClusterTaskQueue(adapterId string, clusterId string, queueNum int64) error {
aId, err := strconv.ParseInt(adapterId, 10, 64) aId, err := strconv.ParseInt(adapterId, 10, 64)
if err != nil { if err != nil {

View File

@ -1,12 +1,17 @@
package service package service
import ( import (
"context"
"fmt"
"github.com/zeromicro/go-zero/zrpc" "github.com/zeromicro/go-zero/zrpc"
hpcacclient "gitlink.org.cn/JointCloud/pcm-ac/hpcacclient" hpcacclient "gitlink.org.cn/JointCloud/pcm-ac/hpcacclient"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/config" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/config"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/database" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/database"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/schedulers/option"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/collector" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/collector"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/executor" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/executor"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/inference"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/strategy"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/storeLink" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/storeLink"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types"
"gitlink.org.cn/JointCloud/pcm-modelarts/client/imagesservice" "gitlink.org.cn/JointCloud/pcm-modelarts/client/imagesservice"
@ -86,6 +91,15 @@ func InitAiClusterMap(conf *config.Config, clusters []types.ClusterInfo) (map[st
return executorMap, collectorMap return executorMap, collectorMap
} }
// ImageInfer runs an image-inference task across the assigned clusters by
// delegating to inference.Infer and printing the outcome. Errors are reported
// on stdout rather than returned because callers invoke this fire-and-forget.
func (as *AiService) ImageInfer(opt *option.InferOption, id int64, adapterName string, clusters []*strategy.AssignedCluster, ts []*inference.ImageFile, ctx context.Context) {
	res, err := inference.Infer(opt, id, adapterName, clusters, ts, as.AiCollectorAdapterMap, as.Storage, ctx)
	if err != nil {
		// Previously the error was silently discarded; surface it so failed
		// inference runs are at least visible.
		fmt.Println("image inference failed:", err)
		return
	}
	fmt.Println(res)
}
//func (a *AiService) AddCluster() error { //func (a *AiService) AddCluster() error {
// //
//} //}

View File

@ -15,10 +15,10 @@ type AiCollector interface {
UploadAlgorithmCode(ctx context.Context, resourceType string, card string, taskType string, dataset string, algorithm string, code string) error UploadAlgorithmCode(ctx context.Context, resourceType string, card string, taskType string, dataset string, algorithm string, code string) error
GetComputeCards(ctx context.Context) ([]string, error) GetComputeCards(ctx context.Context) ([]string, error)
GetUserBalance(ctx context.Context) (float64, error) GetUserBalance(ctx context.Context) (float64, error)
GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*ImageInferUrl, error) GetInferUrl(ctx context.Context, option *option.InferOption) ([]*InferUrl, error)
} }
type ImageInferUrl struct { type InferUrl struct {
Url string Url string
Card string Card string
} }

View File

@ -0,0 +1,385 @@
package inference
import (
"context"
"encoding/json"
"errors"
"github.com/go-resty/resty/v2"
"github.com/zeromicro/go-zero/core/logx"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/database"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/schedulers/option"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/service/collector"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/strategy"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/storeLink"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils"
"log"
"math/rand"
"mime/multipart"
"sort"
"strconv"
"sync"
"time"
)
// ImageFile pairs an uploaded image with the result record that is filled in
// once inference for that image completes.
type ImageFile struct {
	// ImageResult collects the inference outcome (result text, cluster, card).
	ImageResult *types.ImageResult
	// File is the uploaded image content as received from the request.
	File multipart.File
}
// Infer dispatches an image-inference task across the assigned clusters: it
// records one task_ai row per cluster, discovers each cluster's inference
// endpoints concurrently, fans the uploaded images out to those endpoints,
// persists per-image results as task_ai_sub rows, and keeps the task_ai
// status (Running/Completed/Failed) up to date.
//
// The returned results are sorted by cluster name. A non-nil error means no
// cluster produced a usable inference URL or a database access failed; an
// individual image request failure is stored in that image's result entry
// instead of failing the whole call.
func Infer(opt *option.InferOption, id int64, adapterName string, clusters []*strategy.AssignedCluster, ts []*ImageFile, aiCollectorAdapterMap map[string]map[string]collector.AiCollector, storage *database.AiStorage, ctx context.Context) ([]*types.ImageResult, error) {
	// Drop clusters assigned zero replicas; iterate backwards so removal does
	// not shift the indexes still to be visited.
	for i := len(clusters) - 1; i >= 0; i-- {
		if clusters[i].Replicas == 0 {
			clusters = append(clusters[:i], clusters[i+1:]...)
		}
	}
	var wg sync.WaitGroup
	// Per-cluster endpoint discovery results.
	var cluster_ch = make(chan struct {
		urls        []*collector.InferUrl
		clusterId   string
		clusterName string
		imageNum    int32
	}, len(clusters))
	var cs []struct {
		urls        []*collector.InferUrl
		clusterId   string
		clusterName string
		imageNum    int32
	}
	collectorMap := aiCollectorAdapterMap[opt.AdapterId]
	// Persist one task_ai row per assigned cluster in Saved state.
	for _, c := range clusters {
		clusterName, _ := storage.GetClusterNameById(c.ClusterId)
		opt.Replica = c.Replicas
		err := storage.SaveAiTask(id, opt, adapterName, c.ClusterId, clusterName, "", constants.Saved, "")
		if err != nil {
			return nil, err
		}
	}
	// Discover each cluster's inference endpoints concurrently; a cluster
	// whose lookup fails simply does not contribute to cs.
	for _, cluster := range clusters {
		wg.Add(1)
		c := cluster // capture per-iteration value for the goroutine
		go func() {
			defer wg.Done()
			imageUrls, err := collectorMap[c.ClusterId].GetInferUrl(ctx, opt)
			if err != nil {
				// Discovery failed; the cluster is marked failed below.
				return
			}
			// Point every endpoint at its image-inference route. (Was done
			// before the error check; reordered so the slice is only touched
			// on success.)
			for i := range imageUrls {
				imageUrls[i].Url = imageUrls[i].Url + storeLink.FORWARD_SLASH + "image"
			}
			clusterName, _ := storage.GetClusterNameById(c.ClusterId)
			cluster_ch <- struct {
				urls        []*collector.InferUrl
				clusterId   string
				clusterName string
				imageNum    int32
			}{
				urls:        imageUrls,
				clusterId:   c.ClusterId,
				clusterName: clusterName,
				imageNum:    c.Replicas,
			}
		}()
	}
	wg.Wait()
	close(cluster_ch)
	for s := range cluster_ch {
		cs = append(cs, s)
	}
	aiTaskList, err := storage.GetAiTaskListById(id)
	if err != nil {
		return nil, err
	}
	// No cluster available: mark every task_ai row failed and abort.
	if len(cs) == 0 {
		for _, t := range aiTaskList {
			t.Status = constants.Failed
			t.EndTime = time.Now().Format(time.RFC3339)
			if err := storage.UpdateAiTask(t); err != nil {
				logx.Errorf(err.Error())
			}
		}
		storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败")
		return nil, errors.New("image infer task failed")
	}
	// Reconcile task_ai status with the discovery outcome.
	if len(clusters) != len(cs) {
		// Split assigned clusters into those that produced endpoints (rcs)
		// and those whose discovery failed (acs).
		var acs []*strategy.AssignedCluster
		var rcs []*strategy.AssignedCluster
		for _, cluster := range clusters {
			if contains(cs, cluster.ClusterId) {
				rcs = append(rcs, cluster)
			} else {
				acs = append(acs, cluster)
			}
		}
		// Mark failed clusters.
		for _, ac := range acs {
			for _, t := range aiTaskList {
				if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) {
					t.Status = constants.Failed
					t.EndTime = time.Now().Format(time.RFC3339)
					if err := storage.UpdateAiTask(t); err != nil {
						logx.Errorf(err.Error())
					}
				}
			}
		}
		// Mark clusters that will actually run.
		for _, rc := range rcs {
			for _, t := range aiTaskList {
				if rc.ClusterId == strconv.Itoa(int(t.ClusterId)) {
					t.Status = constants.Running
					if err := storage.UpdateAiTask(t); err != nil {
						logx.Errorf(err.Error())
					}
				}
			}
		}
		storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败")
	} else {
		for _, t := range aiTaskList {
			t.Status = constants.Running
			if err := storage.UpdateAiTask(t); err != nil {
				logx.Errorf(err.Error())
			}
		}
		storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "running", "任务运行中")
	}
	var result_ch = make(chan *types.ImageResult, len(ts))
	var results []*types.ImageResult
	// At most 7 in-flight inference requests across all clusters.
	limit := make(chan bool, 7)
	// Partition ts between clusters: each cluster handles imageNum images.
	// (Slicing ts directly replaces the previous full copy-then-slice; the
	// batches are read-only views over the same elements.)
	var imageNumIdx int32 = 0
	var imageNumIdxEnd int32 = 0
	for _, c := range cs {
		imageNumIdxEnd = imageNumIdxEnd + c.imageNum
		batch := ts[imageNumIdx:imageNumIdxEnd]
		imageNumIdx = imageNumIdx + c.imageNum
		wg.Add(len(batch))
		go sendInferReq(batch, c, &wg, result_ch, limit)
	}
	wg.Wait()
	close(result_ch)
	for s := range result_ch {
		results = append(results, s)
	}
	sort.Slice(results, func(p, q int) bool {
		return results[p].ClusterName < results[q].ClusterName
	})
	// Persist one task_ai_sub row per image result.
	for _, r := range results {
		for _, task := range aiTaskList {
			if r.ClusterId == strconv.Itoa(int(task.ClusterId)) {
				taskAiSub := models.TaskAiSub{
					TaskId:      id,
					TaskName:    task.Name,
					TaskAiId:    task.TaskId,
					TaskAiName:  task.Name,
					ImageName:   r.ImageName,
					Result:      r.ImageResult,
					Card:        r.Card,
					ClusterId:   task.ClusterId,
					ClusterName: r.ClusterName,
				}
				if err := storage.SaveAiTaskImageSubTask(&taskAiSub); err != nil {
					// Persisting a sub-task is best-effort: log and continue
					// instead of crashing the scheduler (was: panic(err)).
					logx.Errorf(err.Error())
				}
			}
		}
	}
	// Mark clusters that produced endpoints as completed.
	var successStatusCount int
	for _, c := range cs {
		for _, t := range aiTaskList {
			if c.clusterId == strconv.Itoa(int(t.ClusterId)) {
				t.Status = constants.Completed
				t.EndTime = time.Now().Format(time.RFC3339)
				if err := storage.UpdateAiTask(t); err != nil {
					logx.Errorf(err.Error())
				}
				successStatusCount++
			}
		}
	}
	if len(cs) == successStatusCount {
		storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "completed", "任务完成")
	} else {
		storage.AddNoticeInfo(opt.AdapterId, adapterName, "", "", opt.TaskName, "failed", "任务失败")
	}
	return results, nil
}
// sendInferReq fans the given images out to one cluster's inference
// endpoints. Concurrency is bounded by the limit semaphore: a slot is
// acquired before each request goroutine starts and released by that
// goroutine when it finishes. The caller must have wg.Add-ed once per image.
//
// Fix: the previous version both received from limit in the loop body AND in
// each goroutine — one send but two receives per image — so the semaphore
// never limited anything and every goroutine's final <-limit blocked forever
// (a goroutine leak). The loop-body receive has been removed and the two
// near-identical endpoint branches merged.
func sendInferReq(images []*ImageFile, cluster struct {
	urls        []*collector.InferUrl
	clusterId   string
	clusterName string
	imageNum    int32
}, wg *sync.WaitGroup, ch chan<- *types.ImageResult, limit chan bool) {
	for _, image := range images {
		limit <- true // acquire a concurrency slot
		go func(t *ImageFile, c struct {
			urls        []*collector.InferUrl
			clusterId   string
			clusterName string
			imageNum    int32
		}) {
			defer wg.Done()
			defer func() { <-limit }() // release the slot
			// Pick an endpoint: the only one, or a random one to spread load.
			idx := 0
			if len(c.urls) > 1 {
				idx = rand.Intn(len(c.urls))
			}
			r, err := getInferResult(c.urls[idx].Url, t.File, t.ImageResult.ImageName, c.clusterName)
			if err != nil {
				// Record the failure as this image's result so it is not lost.
				t.ImageResult.ImageResult = err.Error()
			} else {
				t.ImageResult.ImageResult = r
			}
			t.ImageResult.ClusterId = c.clusterId
			t.ImageResult.ClusterName = c.clusterName
			t.ImageResult.Card = c.urls[idx].Card
			ch <- t.ImageResult
		}(image, cluster)
	}
}
// getInferResult sends one image to an inference endpoint and returns the
// textual result. The "鹏城云脑II-modelarts" cluster requires AK/SK request
// signing and is served by a dedicated path.
func getInferResult(url string, file multipart.File, fileName string, clusterName string) (string, error) {
	if clusterName == "鹏城云脑II-modelarts" {
		return getInferResultModelarts(url, file, fileName)
	}
	var res Res
	request := GetRestyRequest(20)
	if _, err := request.
		SetFileReader("file", fileName, file).
		SetResult(&res).
		Post(url); err != nil {
		return "", err
	}
	return res.Result, nil
}
func getInferResultModelarts(url string, file multipart.File, fileName string) (string, error) {
var res Res
/* req := GetRestyRequest(20)
_, err := req.
SetFileReader("file", fileName, file).
SetHeaders(map[string]string{
"ak": "UNEHPHO4Z7YSNPKRXFE4",
"sk": "JWXCE9qcYbc7RjpSRIWt4WgG3ZKF6Q4lPzkJReX9",
}).
SetResult(&res).
Post(url)
if err != nil {
return "", err
}*/
body, err := utils.SendRequest("POST", url, file, fileName)
if err != nil {
return "", err
}
errjson := json.Unmarshal([]byte(body), &res)
if errjson != nil {
log.Fatalf("Error parsing JSON: %s", errjson)
}
return res.Result, nil
}
// GetRestyRequest builds a resty request whose underlying client aborts after
// the given number of seconds.
func GetRestyRequest(timeoutSeconds int64) *resty.Request {
	timeout := time.Duration(timeoutSeconds) * time.Second
	return resty.New().SetTimeout(timeout).R()
}
// Res models the JSON body returned by the inference endpoints.
type Res struct {
	// Result holds the inference output text.
	Result string `json:"result"`
}
// contains reports whether cs holds an entry for the given cluster id.
func contains(cs []struct {
	urls        []*collector.InferUrl
	clusterId   string
	clusterName string
	imageNum    int32
}, e string) bool {
	for i := range cs {
		if cs[i].clusterId == e {
			return true
		}
	}
	return false
}

View File

@ -378,8 +378,8 @@ func (m *ModelArtsLink) generateAlgorithmId(ctx context.Context, option *option.
return errors.New("failed to get AlgorithmId") return errors.New("failed to get AlgorithmId")
} }
func (m *ModelArtsLink) GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.ImageInferUrl, error) { func (m *ModelArtsLink) GetInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.InferUrl, error) {
var imageUrls []*collector.ImageInferUrl var imageUrls []*collector.InferUrl
urlReq := &modelartsclient.ImageReasoningUrlReq{ urlReq := &modelartsclient.ImageReasoningUrlReq{
ModelName: option.ModelName, ModelName: option.ModelName,
Type: option.ModelType, Type: option.ModelType,
@ -389,7 +389,7 @@ func (m *ModelArtsLink) GetImageInferUrl(ctx context.Context, option *option.Inf
if err != nil { if err != nil {
return nil, err return nil, err
} }
imageUrl := &collector.ImageInferUrl{ imageUrl := &collector.InferUrl{
Url: urlResp.Url, Url: urlResp.Url,
Card: "npu", Card: "npu",
} }

View File

@ -871,7 +871,7 @@ func setResourceIdByCard(option *option.AiOption, specs *octopus.GetResourceSpec
return errors.New("set ResourceId error") return errors.New("set ResourceId error")
} }
func (o *OctopusLink) GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.ImageInferUrl, error) { func (o *OctopusLink) GetInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.InferUrl, error) {
req := &octopus.GetNotebookListReq{ req := &octopus.GetNotebookListReq{
Platform: o.platform, Platform: o.platform,
PageIndex: o.pageIndex, PageIndex: o.pageIndex,
@ -882,13 +882,13 @@ func (o *OctopusLink) GetImageInferUrl(ctx context.Context, option *option.Infer
return nil, err return nil, err
} }
var imageUrls []*collector.ImageInferUrl var imageUrls []*collector.InferUrl
for _, notebook := range list.Payload.GetNotebooks() { for _, notebook := range list.Payload.GetNotebooks() {
if strings.Contains(notebook.AlgorithmName, option.ModelName) { if strings.Contains(notebook.AlgorithmName, option.ModelName) && notebook.Status == "running" {
url := strings.Replace(notebook.Tasks[0].Url, FORWARD_SLASH, "", -1) url := strings.Replace(notebook.Tasks[0].Url, FORWARD_SLASH, "", -1)
names := strings.Split(notebook.AlgorithmName, UNDERSCORE) names := strings.Split(notebook.AlgorithmName, UNDERSCORE)
imageUrl := &collector.ImageInferUrl{ imageUrl := &collector.InferUrl{
Url: DOMAIN + url + FORWARD_SLASH + "image", Url: DOMAIN + url,
Card: names[2], Card: names[2],
} }
imageUrls = append(imageUrls, imageUrl) imageUrls = append(imageUrls, imageUrl)
@ -896,5 +896,9 @@ func (o *OctopusLink) GetImageInferUrl(ctx context.Context, option *option.Infer
continue continue
} }
} }
if len(imageUrls) == 0 {
return nil, errors.New("no infer url available")
}
return imageUrls, nil return imageUrls, nil
} }

View File

@ -730,8 +730,8 @@ func (s *ShuguangAi) generateParams(option *option.AiOption) error {
return errors.New("failed to set params") return errors.New("failed to set params")
} }
func (s *ShuguangAi) GetImageInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.ImageInferUrl, error) { func (s *ShuguangAi) GetInferUrl(ctx context.Context, option *option.InferOption) ([]*collector.InferUrl, error) {
var imageUrls []*collector.ImageInferUrl var imageUrls []*collector.InferUrl
urlReq := &hpcAC.GetInferUrlReq{ urlReq := &hpcAC.GetInferUrlReq{
ModelName: option.ModelName, ModelName: option.ModelName,
@ -743,7 +743,7 @@ func (s *ShuguangAi) GetImageInferUrl(ctx context.Context, option *option.InferO
if err != nil { if err != nil {
return nil, err return nil, err
} }
imageUrl := &collector.ImageInferUrl{ imageUrl := &collector.InferUrl{
Url: urlResp.Url, Url: urlResp.Url,
Card: "dcu", Card: "dcu",
} }

View File

@ -78,6 +78,7 @@ var (
} }
ModelTypeMap = map[string][]string{ ModelTypeMap = map[string][]string{
"image_recognition": {"imagenet_resnet50"}, "image_recognition": {"imagenet_resnet50"},
"text_to_text": {"chatGLM_6B"},
} }
AITYPE = map[string]string{ AITYPE = map[string]string{
"1": OCTOPUS, "1": OCTOPUS,

View File

@ -347,8 +347,8 @@ type TaskModel struct {
TenantId string `json:"tenantId,omitempty" db:"tenant_id"` TenantId string `json:"tenantId,omitempty" db:"tenant_id"`
CreatedTime string `json:"createdTime,omitempty" db:"created_time" gorm:"autoCreateTime"` CreatedTime string `json:"createdTime,omitempty" db:"created_time" gorm:"autoCreateTime"`
UpdatedTime string `json:"updatedTime,omitempty" db:"updated_time"` UpdatedTime string `json:"updatedTime,omitempty" db:"updated_time"`
AdapterTypeDict int `json:"adapterTypeDict" db:"adapter_type_dict" gorm:"adapter_type_dict"` //适配器类型(对应字典表的值 AdapterTypeDict string `json:"adapterTypeDict" db:"adapter_type_dict" gorm:"adapter_type_dict"` //适配器类型(对应字典表的值
TaskTypeDict int `json:"taskTypeDict" db:"task_type_dict" gorm:"task_type_dict"` //任务类型(对应字典表的值 TaskTypeDict string `json:"taskTypeDict" db:"task_type_dict" gorm:"task_type_dict"` //任务类型(对应字典表的值
} }
type TaskDetailReq struct { type TaskDetailReq struct {
@ -1205,6 +1205,13 @@ type SubTaskInfo struct {
ClusterName string `json:"clusterName" db:"cluster_name"` ClusterName string `json:"clusterName" db:"cluster_name"`
Status string `json:"status" db:"status"` Status string `json:"status" db:"status"`
Remark string `json:"remark" db:"remark"` Remark string `json:"remark" db:"remark"`
InferUrl string `json:"inferUrl"`
}
type CommonResp struct {
Code int `json:"code,omitempty"`
Msg string `json:"msg,omitempty"`
Data interface{} `json:"data,omitempty"`
} }
type CommitHpcTaskReq struct { type CommitHpcTaskReq struct {
@ -2895,6 +2902,12 @@ type AiTask struct {
TimeElapsed int32 `json:"elapsed,optional"` TimeElapsed int32 `json:"elapsed,optional"`
} }
type ChatReq struct {
ApiUrl string `json:"apiUrl,optional"`
Method string `json:"method,optional"`
ReqData map[string]interface{} `json:"reqData"`
}
type StorageScreenReq struct { type StorageScreenReq struct {
} }
@ -5934,3 +5947,15 @@ type InferenceResult struct {
Card string `json:"card"` Card string `json:"card"`
ClusterName string `json:"clusterName"` ClusterName string `json:"clusterName"`
} }
type TextToTextInferenceReq struct {
TaskName string `form:"taskName"`
TaskDesc string `form:"taskDesc"`
ModelName string `form:"modelName"`
ModelType string `form:"modelType"`
AdapterId string `form:"adapterId"`
AiClusterIds []string `form:"aiClusterIds"`
}
type TextToTextInferenceResp struct {
}

View File

@ -54,6 +54,7 @@ type (
TaskType string `db:"task_type"` TaskType string `db:"task_type"`
DeletedAt *time.Time `db:"deleted_at"` DeletedAt *time.Time `db:"deleted_at"`
Card string `db:"card"` Card string `db:"card"`
InferUrl string `db:"infer_url"`
} }
) )

84
pkg/utils/aksk_sign.go Normal file
View File

@ -0,0 +1,84 @@
package utils
import (
"bytes"
"crypto/tls"
"fmt"
"github.com/JCCE-nudt/apigw-go-sdk/core"
"io"
"mime/multipart"
"net/http"
)
// SignClient applies AK/SK (SDK-HMAC-SHA256) signing headers to r and returns
// an HTTP client suitable for sending the signed request.
//
// SECURITY(review): the project id and the AK/SK credentials below are
// hard-coded in source; move them to configuration / secret storage.
func SignClient(r *http.Request, writer *multipart.Writer) (*http.Client, error) {
	r.Header.Add("X-Project-Id", "d18190e28e3f45a281ef0b0696ec9d52")
	r.Header.Add("x-stage", "RELEASE")
	r.Header.Add("x-sdk-content-sha256", "UNSIGNED-PAYLOAD")
	// The multipart boundary must be part of the signed Content-Type. (The
	// previous "application/json" Content-Type was always overwritten by this
	// Set before signing, so that redundant Add has been dropped.)
	r.Header.Set("Content-Type", writer.FormDataContentType())
	s := core.Signer{
		Key:    "UNEHPHO4Z7YSNPKRXFE4",
		Secret: "JWXCE9qcYbc7RjpSRIWt4WgG3ZKF6Q4lPzkJReX9",
	}
	if err := s.Sign(r); err != nil {
		return nil, err
	}
	// NOTE(review): InsecureSkipVerify disables TLS certificate verification
	// for every request sent through this client — confirm this is acceptable
	// for the target gateway.
	tr := &http.Transport{
		TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
	}
	return &http.Client{Transport: tr}, nil
}
// SendRequest uploads file as a multipart form field named "file" to url with
// an AK/SK-signed request and returns the raw response body as a string.
//
// Fix: the previous version printed each error and kept going (risking nil
// dereferences on a failed request) and never closed the response body; every
// failure now returns immediately and the body is closed.
func SendRequest(method, url string, file multipart.File, fileName string) (string, error) {
	// Build the multipart form body.
	var body bytes.Buffer
	writer := multipart.NewWriter(&body)
	part, err := writer.CreateFormFile("file", fileName)
	if err != nil {
		return "", fmt.Errorf("creating form file: %w", err)
	}
	if _, err = io.Copy(part, file); err != nil {
		return "", fmt.Errorf("copying file data: %w", err)
	}
	// Close the writer to flush the terminating boundary.
	if err = writer.Close(); err != nil {
		return "", fmt.Errorf("closing multipart writer: %w", err)
	}
	request, err := http.NewRequest(method, url, &body)
	if err != nil {
		return "", fmt.Errorf("creating request: %w", err)
	}
	client, err := SignClient(request, writer)
	if err != nil {
		return "", fmt.Errorf("signing request: %w", err)
	}
	res, err := client.Do(request)
	if err != nil {
		return "", err
	}
	// Always close the body so the connection can be reused.
	defer res.Body.Close()
	respBody, err := io.ReadAll(res.Body)
	if err != nil {
		return "", fmt.Errorf("reading response body: %w", err)
	}
	return string(respBody), nil
}

38
pkg/utils/hws/escape.go Normal file
View File

@ -0,0 +1,38 @@
package hws
// shouldEscape reports whether byte c must be percent-encoded. Only the
// RFC 3986 unreserved characters (A-Z, a-z, 0-9, '-', '_', '.', '~') pass
// through unescaped.
func shouldEscape(c byte) bool {
	switch {
	case 'A' <= c && c <= 'Z', 'a' <= c && c <= 'z', '0' <= c && c <= '9':
		return false
	case c == '_', c == '-', c == '~', c == '.':
		return false
	}
	return true
}
// Escape percent-encodes every byte of s that is not an RFC 3986 unreserved
// character (A-Z, a-z, 0-9, '-', '_', '.', '~'). Input containing only
// unreserved characters is returned unchanged without allocating.
func Escape(s string) string {
	const upperhex = "0123456789ABCDEF"
	unreserved := func(c byte) bool {
		return 'A' <= c && c <= 'Z' ||
			'a' <= c && c <= 'z' ||
			'0' <= c && c <= '9' ||
			c == '_' || c == '-' || c == '~' || c == '.'
	}
	// First pass: count bytes that need escaping.
	escaped := 0
	for i := 0; i < len(s); i++ {
		if !unreserved(s[i]) {
			escaped++
		}
	}
	if escaped == 0 {
		return s
	}
	// Second pass: emit, expanding each escaped byte to %XX.
	out := make([]byte, 0, len(s)+2*escaped)
	for i := 0; i < len(s); i++ {
		c := s[i]
		if unreserved(c) {
			out = append(out, c)
		} else {
			out = append(out, '%', upperhex[c>>4], upperhex[c&0x0F])
		}
	}
	return string(out)
}

184
pkg/utils/hws/signer.go Normal file
View File

@ -0,0 +1,184 @@
package hws
import (
"bytes"
"crypto/hmac"
"crypto/sha256"
"fmt"
"io/ioutil"
"net/http"
"sort"
"strings"
"time"
)
const (
	// DateFormat is the timestamp layout required by the signing protocol
	// (basic ISO 8601 in UTC, e.g. 20060102T150405Z).
	DateFormat = "20060102T150405Z"
	// SignAlgorithm names the signing scheme in the Authorization header.
	SignAlgorithm = "SDK-HMAC-SHA256"
	// Header names used by the signer.
	HeaderXDateTime = "X-Sdk-Date"
	HeaderXHost = "host"
	HeaderXAuthorization = "Authorization"
	HeaderXContentSha256 = "X-Sdk-Content-Sha256"
)
func hmacsha256(keyByte []byte, dataStr string) ([]byte, error) {
hm := hmac.New(sha256.New, []byte(keyByte))
if _, err := hm.Write([]byte(dataStr)); err != nil {
return nil, err
}
return hm.Sum(nil), nil
}
// CanonicalRequest builds the canonical string used as signing input: method,
// canonical URI, canonical query, canonical headers, the signed-header list,
// and the hex SHA-256 of the payload, newline-separated. A caller-provided
// X-Sdk-Content-Sha256 header (e.g. "UNSIGNED-PAYLOAD") overrides the
// computed body hash.
//
// Fix: the previous version declared an outer err that was shadowed inside
// the else branch and therefore always nil when returned on the success path;
// the success return is now an explicit nil.
func CanonicalRequest(request *http.Request, signedHeaders []string) (string, error) {
	var hexencode string
	if hex := request.Header.Get(HeaderXContentSha256); hex != "" {
		hexencode = hex
	} else {
		bodyData, err := RequestPayload(request)
		if err != nil {
			return "", err
		}
		hexencode, err = HexEncodeSHA256Hash(bodyData)
		if err != nil {
			return "", err
		}
	}
	return fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", request.Method, CanonicalURI(request), CanonicalQueryString(request), CanonicalHeaders(request, signedHeaders), strings.Join(signedHeaders, ";"), hexencode), nil
}
// CanonicalURI returns the percent-escaped request path with each segment
// escaped individually and a guaranteed trailing slash, as the signing
// protocol requires.
func CanonicalURI(request *http.Request) string {
	segments := strings.Split(request.URL.Path, "/")
	escaped := make([]string, 0, len(segments))
	for _, seg := range segments {
		escaped = append(escaped, Escape(seg))
	}
	p := strings.Join(escaped, "/")
	if !strings.HasSuffix(p, "/") {
		p += "/"
	}
	return p
}
// CanonicalQueryString builds the signing-canonical form of the query string:
// keys sorted ascending, each key and value percent-escaped, and the values
// of a repeated key sorted as well.
//
// Side effect: request.URL.RawQuery is rewritten to the canonical string so
// the transmitted query matches exactly what was signed.
func CanonicalQueryString(request *http.Request) string {
	var keys []string
	// Query() parses RawQuery into a fresh url.Values map.
	queryMap := request.URL.Query()
	for key := range queryMap {
		keys = append(keys, key)
	}
	sort.Strings(keys)
	var query []string
	for _, key := range keys {
		k := Escape(key)
		// Sort the values of a repeated key for a deterministic ordering.
		sort.Strings(queryMap[key])
		for _, v := range queryMap[key] {
			kv := fmt.Sprintf("%s=%s", k, Escape(v))
			query = append(query, kv)
		}
	}
	queryStr := strings.Join(query, "&")
	// Replace the outgoing query with the canonical form that was signed.
	request.URL.RawQuery = queryStr
	return queryStr
}
// CanonicalHeaders renders the signed headers in canonical form: one
// "name:value" line per value, names lowercased, values trimmed, the values
// of a repeated name sorted, and a trailing newline. The "host" header is
// always taken from request.Host.
func CanonicalHeaders(request *http.Request, signerHeaders []string) string {
	// Index the request headers by lowercased name.
	lowered := make(map[string][]string, len(request.Header))
	for name, values := range request.Header {
		lowered[strings.ToLower(name)] = values
	}
	var lines []string
	for _, name := range signerHeaders {
		values := lowered[name]
		if strings.EqualFold(name, HeaderXHost) {
			values = []string{request.Host}
		}
		sort.Strings(values)
		for _, v := range values {
			lines = append(lines, name+":"+strings.TrimSpace(v))
		}
	}
	return strings.Join(lines, "\n") + "\n"
}
func SignedHeaders(r *http.Request) []string {
var signedHeaders []string
for key := range r.Header {
signedHeaders = append(signedHeaders, strings.ToLower(key))
}
sort.Strings(signedHeaders)
return signedHeaders
}
func RequestPayload(request *http.Request) ([]byte, error) {
if request.Body == nil {
return []byte(""), nil
}
bodyByte, err := ioutil.ReadAll(request.Body)
if err != nil {
return []byte(""), err
}
request.Body = ioutil.NopCloser(bytes.NewBuffer(bodyByte))
return bodyByte, err
}
// StringToSign assembles the final signing input: the algorithm name, the
// UTC timestamp, and the hex SHA-256 of the canonical request, each on its
// own line. (sha256 hashing cannot fail, so the error is always nil; the
// signature is kept for callers.)
func StringToSign(canonicalRequest string, t time.Time) (string, error) {
	digest := sha256.Sum256([]byte(canonicalRequest))
	return fmt.Sprintf("%s\n%s\n%x",
		SignAlgorithm, t.UTC().Format(DateFormat), digest[:]), nil
}
// SignStringToSign computes the hex-encoded HMAC-SHA256 signature of
// stringToSign under signingKey.
func SignStringToSign(stringToSign string, signingKey []byte) (string, error) {
	mac, err := hmacsha256(signingKey, stringToSign)
	if err != nil {
		return "", err
	}
	return fmt.Sprintf("%x", mac), nil
}
// HexEncodeSHA256Hash returns the lowercase hex SHA-256 digest of body.
// A nil/empty body hashes to the well-known empty-input digest.
func HexEncodeSHA256Hash(body []byte) (string, error) {
	// The previous len(body)==0 special case was dead code: hashing nil and
	// hashing []byte("") are identical.
	h := sha256.New()
	if _, err := h.Write(body); err != nil {
		return "", err
	}
	return fmt.Sprintf("%x", h.Sum(nil)), nil
}
func AuthHeaderValue(signatureStr, accessKeyStr string, signedHeaders []string) string {
return fmt.Sprintf("%s Access=%s, SignedHeaders=%s, Signature=%s", SignAlgorithm, accessKeyStr, strings.Join(signedHeaders, ";"), signatureStr)
}
// Signer signs HTTP requests with the SDK-HMAC-SHA256 scheme using an
// access-key / secret-key pair.
type Signer struct {
	// Key is the access key id placed in the Authorization header.
	Key string
	// Secret is the secret key used to compute the HMAC signature.
	Secret string
}
// Sign computes the SDK-HMAC-SHA256 signature for request and sets the
// Authorization header. A valid caller-supplied X-Sdk-Date header is reused
// as the signing timestamp; otherwise the request is stamped with the
// current UTC time.
func (s *Signer) Sign(request *http.Request) error {
	var t time.Time
	haveDate := false
	if date := request.Header.Get(HeaderXDateTime); date != "" {
		if parsed, err := time.Parse(DateFormat, date); err == nil {
			t = parsed
			haveDate = true
		}
	}
	if !haveDate {
		t = time.Now()
		request.Header.Set(HeaderXDateTime, t.UTC().Format(DateFormat))
	}
	signedHeaders := SignedHeaders(request)
	canonical, err := CanonicalRequest(request, signedHeaders)
	if err != nil {
		return err
	}
	toSign, err := StringToSign(canonical, t)
	if err != nil {
		return err
	}
	signature, err := SignStringToSign(toSign, []byte(s.Secret))
	if err != nil {
		return err
	}
	request.Header.Set(HeaderXAuthorization, AuthHeaderValue(signature, s.Key, signedHeaders))
	return nil
}