534 lines
14 KiB
Go
534 lines
14 KiB
Go
package imageInference
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"github.com/go-resty/resty/v2"
|
|
"github.com/zeromicro/go-zero/core/logx"
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/database"
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option"
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference"
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/strategy"
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/storeLink"
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/internal/types"
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
|
|
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils"
|
|
"log"
|
|
"math/rand"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"sort"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
type IImageInference interface {
|
|
AppendRoute(urls []*inference.InferUrl) error
|
|
GetAiType() string
|
|
}
|
|
|
|
type ImageFile struct {
|
|
ImageResult *types.ImageResult
|
|
File multipart.File
|
|
}
|
|
|
|
type FilteredCluster struct {
|
|
urls []*inference.InferUrl
|
|
clusterId string
|
|
clusterName string
|
|
clusterType string
|
|
imageNum int32
|
|
}
|
|
|
|
type ImageInference struct {
|
|
inference IImageInference
|
|
files []*ImageFile
|
|
clusters []*strategy.AssignedCluster
|
|
instances []*models.AiInferDeployInstance
|
|
opt *option.InferOption
|
|
storage *database.AiStorage
|
|
inferAdapter map[string]map[string]inference.ICluster
|
|
errMap map[string]string
|
|
adapterName string
|
|
}
|
|
|
|
func New(
|
|
inference IImageInference,
|
|
files []*ImageFile,
|
|
clusters []*strategy.AssignedCluster,
|
|
instances []*models.AiInferDeployInstance,
|
|
opt *option.InferOption,
|
|
storage *database.AiStorage,
|
|
inferAdapter map[string]map[string]inference.ICluster,
|
|
adapterName string) (*ImageInference, error) {
|
|
|
|
return &ImageInference{
|
|
inference: inference,
|
|
files: files,
|
|
clusters: clusters,
|
|
instances: instances,
|
|
opt: opt,
|
|
storage: storage,
|
|
inferAdapter: inferAdapter,
|
|
adapterName: adapterName,
|
|
errMap: make(map[string]string),
|
|
}, nil
|
|
}
|
|
|
|
func (i *ImageInference) CreateTask() (int64, error) {
|
|
id, err := i.saveTask()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
err = i.saveAiTask(id)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return id, nil
|
|
}
|
|
|
|
func (i *ImageInference) InferTask(id int64) error {
|
|
clusters, err := i.filterClusters()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
aiTaskList, err := i.storage.GetAiTaskListById(id)
|
|
if err != nil || len(aiTaskList) == 0 {
|
|
return err
|
|
}
|
|
|
|
err = i.updateStatus(aiTaskList, clusters)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
results, err := i.inferImages(clusters)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = i.saveAiSubTasks(id, aiTaskList, clusters, results)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (i *ImageInference) saveTask() (int64, error) {
|
|
var synergystatus int64
|
|
if len(i.clusters) > 1 {
|
|
synergystatus = 1
|
|
}
|
|
|
|
strategyCode, err := i.storage.GetStrategyCode(i.opt.Strategy)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
id, err := i.storage.SaveTask(i.opt.TaskName, strategyCode, synergystatus, i.inference.GetAiType())
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
i.storage.AddNoticeInfo("", "", "", "", i.opt.TaskName, "create", "任务创建中")
|
|
|
|
return id, nil
|
|
}
|
|
|
|
func (i *ImageInference) saveAiTask(id int64) error {
|
|
for _, c := range i.clusters {
|
|
clusterName, _ := i.storage.GetClusterNameById(c.ClusterId)
|
|
i.opt.Replica = c.Replicas
|
|
err := i.storage.SaveAiTask(id, i.opt, i.adapterName, c.ClusterId, clusterName, "", constants.Saved, "")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (i *ImageInference) filterClustersTemp() ([]*FilteredCluster, error) {
|
|
var wg sync.WaitGroup
|
|
var ch = make(chan *FilteredCluster, len(i.clusters))
|
|
var cs []*FilteredCluster
|
|
var mutex sync.Mutex
|
|
|
|
inferMap := i.inferAdapter[i.opt.AdapterId]
|
|
|
|
for _, cluster := range i.clusters {
|
|
wg.Add(1)
|
|
c := cluster
|
|
go func() {
|
|
r := http.Request{}
|
|
clusterInferUrl, err := inferMap[c.ClusterId].GetClusterInferUrl(r.Context(), i.opt)
|
|
if err != nil {
|
|
mutex.Lock()
|
|
i.errMap[c.ClusterId] = err.Error()
|
|
mutex.Unlock()
|
|
wg.Done()
|
|
return
|
|
}
|
|
|
|
i.inference.AppendRoute(clusterInferUrl.InferUrls)
|
|
|
|
var f FilteredCluster
|
|
f.urls = clusterInferUrl.InferUrls
|
|
f.clusterId = c.ClusterId
|
|
f.clusterName = clusterInferUrl.ClusterName
|
|
f.clusterType = clusterInferUrl.ClusterType
|
|
f.imageNum = c.Replicas
|
|
|
|
ch <- &f
|
|
wg.Done()
|
|
return
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
close(ch)
|
|
|
|
for s := range ch {
|
|
cs = append(cs, s)
|
|
}
|
|
return cs, nil
|
|
}
|
|
|
|
func (i *ImageInference) filterClusters() ([]*FilteredCluster, error) {
|
|
var cs []*FilteredCluster
|
|
for _, cluster := range i.clusters {
|
|
var inferurls []*inference.InferUrl
|
|
var clustertype string
|
|
var clusterName string
|
|
|
|
for _, instance := range i.instances {
|
|
clusterId := strconv.FormatInt(instance.ClusterId, 10)
|
|
adapterId := strconv.FormatInt(instance.AdapterId, 10)
|
|
|
|
if cluster.ClusterId == clusterId {
|
|
r := http.Request{}
|
|
deployInstance, err := i.inferAdapter[adapterId][clusterId].GetInferDeployInstance(r.Context(), instance.InstanceId)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
var url inference.InferUrl
|
|
url.Url = deployInstance.InferUrl
|
|
url.Card = deployInstance.InferCard
|
|
inferurls = append(inferurls, &url)
|
|
|
|
clustertype = deployInstance.ClusterType
|
|
clusterName = deployInstance.ClusterName
|
|
}
|
|
}
|
|
if len(inferurls) == 0 {
|
|
continue
|
|
}
|
|
|
|
i.inference.AppendRoute(inferurls)
|
|
|
|
var f FilteredCluster
|
|
f.urls = inferurls
|
|
f.clusterId = cluster.ClusterId
|
|
f.clusterName = clusterName
|
|
f.clusterType = clustertype
|
|
f.imageNum = cluster.Replicas
|
|
cs = append(cs, &f)
|
|
}
|
|
return cs, nil
|
|
}
|
|
|
|
func (i *ImageInference) inferImages(cs []*FilteredCluster) ([]*types.ImageResult, error) {
|
|
var wg sync.WaitGroup
|
|
var ch = make(chan *types.ImageResult, len(i.files))
|
|
var results []*types.ImageResult
|
|
limit := make(chan bool, 5)
|
|
|
|
var imageNumIdx int32 = 0
|
|
var imageNumIdxEnd int32 = 0
|
|
for _, c := range cs {
|
|
new_images := make([]*ImageFile, len(i.files))
|
|
copy(new_images, i.files)
|
|
|
|
imageNumIdxEnd = imageNumIdxEnd + c.imageNum
|
|
new_images = new_images[imageNumIdx:imageNumIdxEnd]
|
|
imageNumIdx = imageNumIdx + c.imageNum
|
|
|
|
wg.Add(len(new_images))
|
|
go i.sendInferReq(new_images, c, &wg, ch, limit)
|
|
}
|
|
wg.Wait()
|
|
close(ch)
|
|
|
|
for s := range ch {
|
|
results = append(results, s)
|
|
}
|
|
|
|
sort.Slice(results, func(p, q int) bool {
|
|
return results[p].ClusterName < results[q].ClusterName
|
|
})
|
|
|
|
return results, nil
|
|
}
|
|
|
|
func (i *ImageInference) updateStatus(aiTaskList []*models.TaskAi, cs []*FilteredCluster) error {
|
|
|
|
//no cluster available
|
|
if len(cs) == 0 {
|
|
for _, t := range aiTaskList {
|
|
t.Status = constants.Failed
|
|
t.EndTime = time.Now().Format(time.RFC3339)
|
|
if _, ok := i.errMap[strconv.Itoa(int(t.ClusterId))]; ok {
|
|
t.Msg = i.errMap[strconv.Itoa(int(t.ClusterId))]
|
|
}
|
|
err := i.storage.UpdateAiTask(t)
|
|
if err != nil {
|
|
logx.Errorf(err.Error())
|
|
}
|
|
}
|
|
i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "failed", "任务失败")
|
|
return errors.New("available clusters' empty, image infer task failed")
|
|
}
|
|
|
|
//change cluster status
|
|
if len(i.clusters) != len(cs) {
|
|
var failedclusters []*strategy.AssignedCluster
|
|
var runningclusters []*strategy.AssignedCluster
|
|
for _, cluster := range i.clusters {
|
|
if contains(cs, cluster.ClusterId) {
|
|
var ac *strategy.AssignedCluster
|
|
ac = cluster
|
|
runningclusters = append(runningclusters, ac)
|
|
} else {
|
|
var ac *strategy.AssignedCluster
|
|
ac = cluster
|
|
failedclusters = append(failedclusters, ac)
|
|
}
|
|
}
|
|
|
|
// update failed cluster status
|
|
for _, ac := range failedclusters {
|
|
for _, t := range aiTaskList {
|
|
if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) {
|
|
t.Status = constants.Failed
|
|
t.EndTime = time.Now().Format(time.RFC3339)
|
|
if _, ok := i.errMap[strconv.Itoa(int(t.ClusterId))]; ok {
|
|
t.Msg = i.errMap[strconv.Itoa(int(t.ClusterId))]
|
|
}
|
|
err := i.storage.UpdateAiTask(t)
|
|
if err != nil {
|
|
logx.Errorf(err.Error())
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// update running cluster status
|
|
for _, ac := range runningclusters {
|
|
for _, t := range aiTaskList {
|
|
if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) {
|
|
t.Status = constants.Running
|
|
err := i.storage.UpdateAiTask(t)
|
|
if err != nil {
|
|
logx.Errorf(err.Error())
|
|
}
|
|
}
|
|
}
|
|
}
|
|
i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "failed", "任务失败")
|
|
} else {
|
|
for _, t := range aiTaskList {
|
|
t.Status = constants.Running
|
|
err := i.storage.UpdateAiTask(t)
|
|
if err != nil {
|
|
logx.Errorf(err.Error())
|
|
}
|
|
}
|
|
i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "running", "任务运行中")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (i *ImageInference) sendInferReq(images []*ImageFile, cluster *FilteredCluster, wg *sync.WaitGroup, ch chan<- *types.ImageResult, limit chan bool) {
|
|
for _, image := range images {
|
|
limit <- true
|
|
go func(t *ImageFile, c *FilteredCluster) {
|
|
if len(c.urls) == 1 {
|
|
r, err := getInferResult(c.urls[0].Url, t.File, t.ImageResult.ImageName, c.clusterId, c.clusterType, i.inferAdapter, i.opt.AdapterId)
|
|
if err != nil {
|
|
t.ImageResult.ImageResult = err.Error()
|
|
t.ImageResult.ClusterId = c.clusterId
|
|
t.ImageResult.ClusterName = c.clusterName
|
|
t.ImageResult.Card = c.urls[0].Card
|
|
ch <- t.ImageResult
|
|
wg.Done()
|
|
<-limit
|
|
return
|
|
}
|
|
t.ImageResult.ImageResult = r
|
|
t.ImageResult.ClusterId = c.clusterId
|
|
t.ImageResult.ClusterName = c.clusterName
|
|
t.ImageResult.Card = c.urls[0].Card
|
|
|
|
ch <- t.ImageResult
|
|
wg.Done()
|
|
<-limit
|
|
return
|
|
} else {
|
|
idx := rand.Intn(len(c.urls))
|
|
r, err := getInferResult(c.urls[idx].Url, t.File, t.ImageResult.ImageName, c.clusterId, c.clusterType, i.inferAdapter, i.opt.AdapterId)
|
|
if err != nil {
|
|
t.ImageResult.ImageResult = err.Error()
|
|
t.ImageResult.ClusterId = c.clusterId
|
|
t.ImageResult.ClusterName = c.clusterName
|
|
t.ImageResult.Card = c.urls[idx].Card
|
|
ch <- t.ImageResult
|
|
wg.Done()
|
|
<-limit
|
|
return
|
|
}
|
|
t.ImageResult.ImageResult = r
|
|
t.ImageResult.ClusterId = c.clusterId
|
|
t.ImageResult.ClusterName = c.clusterName
|
|
t.ImageResult.Card = c.urls[idx].Card
|
|
|
|
ch <- t.ImageResult
|
|
wg.Done()
|
|
<-limit
|
|
return
|
|
}
|
|
}(image, cluster)
|
|
}
|
|
}
|
|
|
|
func (i *ImageInference) saveAiSubTasks(id int64, aiTaskList []*models.TaskAi, cs []*FilteredCluster, results []*types.ImageResult) error {
|
|
//save ai sub tasks
|
|
for _, r := range results {
|
|
for _, task := range aiTaskList {
|
|
if r.ClusterId == strconv.Itoa(int(task.ClusterId)) {
|
|
taskAiSub := models.TaskAiSub{
|
|
TaskId: id,
|
|
TaskName: task.Name,
|
|
TaskAiId: task.TaskId,
|
|
TaskAiName: task.Name,
|
|
ImageName: r.ImageName,
|
|
Result: r.ImageResult,
|
|
Card: r.Card,
|
|
ClusterId: task.ClusterId,
|
|
ClusterName: r.ClusterName,
|
|
}
|
|
err := i.storage.SaveAiTaskImageSubTask(&taskAiSub)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// update succeeded cluster status
|
|
var successStatusCount int
|
|
for _, c := range cs {
|
|
for _, t := range aiTaskList {
|
|
if c.clusterId == strconv.Itoa(int(t.ClusterId)) {
|
|
t.Status = constants.Completed
|
|
t.EndTime = time.Now().Format(time.RFC3339)
|
|
err := i.storage.UpdateAiTask(t)
|
|
if err != nil {
|
|
logx.Errorf(err.Error())
|
|
}
|
|
successStatusCount++
|
|
} else {
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(cs) == successStatusCount {
|
|
i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "completed", "任务完成")
|
|
} else {
|
|
i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "failed", "任务失败")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func getInferResult(url string, file multipart.File, fileName string, clusterId string, clusterType string, inferAdapter map[string]map[string]inference.ICluster, adapterId string) (string, error) {
|
|
adapter, found := inferAdapter[adapterId]
|
|
if !found {
|
|
return "", errors.New("adapterId not found")
|
|
}
|
|
iCluster, found := adapter[clusterId]
|
|
if !found {
|
|
return "", errors.New("clusterId not found")
|
|
}
|
|
|
|
switch clusterType {
|
|
case storeLink.TYPE_OCTOPUS:
|
|
r := http.Request{}
|
|
result, err := iCluster.GetInferResult(r.Context(), url, file, fileName)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return result, nil
|
|
case storeLink.TYPE_MODELARTS:
|
|
r, err := getInferResultModelarts(url, file, fileName)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return r, nil
|
|
default:
|
|
var res Res
|
|
req := GetRestyRequest(20)
|
|
_, err := req.
|
|
SetFileReader("file", fileName, file).
|
|
SetResult(&res).
|
|
Post(url)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return res.Result, nil
|
|
}
|
|
}
|
|
|
|
func getInferResultModelarts(url string, file multipart.File, fileName string) (string, error) {
|
|
var res Res
|
|
/* req := GetRestyRequest(20)
|
|
_, err := req.
|
|
SetFileReader("file", fileName, file).
|
|
SetHeaders(map[string]string{
|
|
"ak": "UNEHPHO4Z7YSNPKRXFE4",
|
|
"sk": "JWXCE9qcYbc7RjpSRIWt4WgG3ZKF6Q4lPzkJReX9",
|
|
}).
|
|
SetResult(&res).
|
|
Post(url)
|
|
if err != nil {
|
|
return "", err
|
|
}*/
|
|
body, err := utils.SendRequest("POST", url, file, fileName)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
errjson := json.Unmarshal([]byte(body), &res)
|
|
if errjson != nil {
|
|
log.Fatalf("Error parsing JSON: %s", errjson)
|
|
}
|
|
return res.Result, nil
|
|
}
|
|
|
|
func GetRestyRequest(timeoutSeconds int64) *resty.Request {
|
|
client := resty.New().SetTimeout(time.Duration(timeoutSeconds) * time.Second)
|
|
request := client.R()
|
|
return request
|
|
}
|
|
|
|
// Res is the JSON body returned by an inference endpoint.
type Res struct {
	// Result is decoded from the "result" key of the response.
	Result string `json:"result"`
}
|
|
|
|
func contains(cs []*FilteredCluster, e string) bool {
|
|
for _, c := range cs {
|
|
if c.clusterId == e {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|