pcm-coordinator/internal/scheduler/service/inference/imageInference/imageInference.go

// Package imageInference fans image inference requests out to the clusters
// assigned by the scheduler and aggregates the per-image results.
package imageInference

import (
	"context"
	"encoding/json"
	"errors"
	"math/rand"
	"mime/multipart"
	"sort"
	"strconv"
	"sync"
	"time"

	"github.com/go-resty/resty/v2"
	"github.com/zeromicro/go-zero/core/logx"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/database"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/schedulers/option"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/service/inference"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/scheduler/strategy"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/storeLink"
	"gitlink.org.cn/JointCloud/pcm-coordinator/internal/types"
	"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
	"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
	"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/utils"
)
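// IImageInference is implemented by concrete image inference services; it
// collects the routable inference URLs and reports the AI task type.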
type IImageInference interface {
AppendRoute(urls []*inference.InferUrl) error
GetAiType() string
}
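// ImageFile pairs an uploaded image with the result entry that will be
// filled in once inference completes.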
type ImageFile struct {
ImageResult *types.ImageResult
File multipart.File
}
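// FilteredCluster describes a cluster that passed filtering: its reachable
// inference URLs and the number of images assigned to it.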
type FilteredCluster struct {
urls []*inference.InferUrl
clusterId string
clusterName string
clusterType string
imageNum int32
}
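// ImageInference orchestrates an image inference task: it persists the task
// records, filters usable clusters, dispatches images to them and stores the
// per-image results. errMap records failure reasons keyed by cluster id.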
type ImageInference struct {
inference IImageInference
files []*ImageFile
clusters []*strategy.AssignedCluster
instances []*models.AiInferDeployInstance
opt *option.InferOption
storage *database.AiStorage
inferAdapter map[string]map[string]inference.ICluster
errMap map[string]string
adapterName string
}
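// New assembles an ImageInference from the scheduling results and storage handles.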
func New(
inference IImageInference,
files []*ImageFile,
clusters []*strategy.AssignedCluster,
instances []*models.AiInferDeployInstance,
opt *option.InferOption,
storage *database.AiStorage,
inferAdapter map[string]map[string]inference.ICluster,
adapterName string) (*ImageInference, error) {
return &ImageInference{
inference: inference,
files: files,
clusters: clusters,
instances: instances,
opt: opt,
storage: storage,
inferAdapter: inferAdapter,
adapterName: adapterName,
errMap: make(map[string]string),
}, nil
}
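// CreateTask persists the top-level task record plus one AI task per
// assigned cluster and returns the new task id.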
func (i *ImageInference) CreateTask() (int64, error) {
id, err := i.saveTask()
if err != nil {
return 0, err
}
err = i.saveAiTask(id)
if err != nil {
return 0, err
}
return id, nil
}
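// InferTask runs the inference flow for a previously created task: filter
// usable clusters, update task statuses, fan the images out for inference
// and persist the per-image sub task results.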
func (i *ImageInference) InferTask(id int64) error {
clusters, err := i.filterClusters()
if err != nil {
return err
}
	aiTaskList, err := i.storage.GetAiTaskListById(id)
	if err != nil {
		return err
	}
	if len(aiTaskList) == 0 {
		return errors.New("no ai task records found for task id " + strconv.FormatInt(id, 10))
	}
err = i.updateStatus(aiTaskList, clusters)
if err != nil {
return err
}
results, err := i.inferImages(clusters)
if err != nil {
return err
}
err = i.saveAiSubTasks(id, aiTaskList, clusters, results)
if err != nil {
return err
}
return nil
}
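// saveTask stores the top-level task record, marking it as synergetic when
// more than one cluster is assigned.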
func (i *ImageInference) saveTask() (int64, error) {
var synergystatus int64
if len(i.clusters) > 1 {
synergystatus = 1
}
strategyCode, err := i.storage.GetStrategyCode(i.opt.Strategy)
if err != nil {
return 0, err
}
id, err := i.storage.SaveTask(i.opt.TaskName, strategyCode, synergystatus, i.inference.GetAiType())
if err != nil {
return 0, err
}
i.storage.AddNoticeInfo("", "", "", "", i.opt.TaskName, "create", "任务创建中")
return id, nil
}
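// saveAiTask stores one AI task row per assigned cluster under the given
// task id, in the Saved state.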
func (i *ImageInference) saveAiTask(id int64) error {
for _, c := range i.clusters {
clusterName, _ := i.storage.GetClusterNameById(c.ClusterId)
i.opt.Replica = c.Replicas
err := i.storage.SaveAiTask(id, i.opt, i.adapterName, c.ClusterId, clusterName, "", constants.Saved, "")
if err != nil {
return err
}
}
return nil
}
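// filterClustersTemp is a concurrent variant of filterClusters that queries
// each cluster adapter for its inference URLs directly instead of matching
// against the known deploy instances.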
func (i *ImageInference) filterClustersTemp() ([]*FilteredCluster, error) {
var wg sync.WaitGroup
var ch = make(chan *FilteredCluster, len(i.clusters))
var cs []*FilteredCluster
var mutex sync.Mutex
inferMap := i.inferAdapter[i.opt.AdapterId]
for _, cluster := range i.clusters {
wg.Add(1)
c := cluster
		go func() {
			defer wg.Done()
			clusterInferUrl, err := inferMap[c.ClusterId].GetClusterInferUrl(context.Background(), i.opt)
			if err != nil {
				mutex.Lock()
				i.errMap[c.ClusterId] = err.Error()
				mutex.Unlock()
				return
			}
			if err := i.inference.AppendRoute(clusterInferUrl.InferUrls); err != nil {
				logx.Errorf("append inference routes for cluster %s failed: %v", c.ClusterId, err)
			}
			ch <- &FilteredCluster{
				urls:        clusterInferUrl.InferUrls,
				clusterId:   c.ClusterId,
				clusterName: clusterInferUrl.ClusterName,
				clusterType: clusterInferUrl.ClusterType,
				imageNum:    c.Replicas,
			}
		}()
}
wg.Wait()
close(ch)
for s := range ch {
cs = append(cs, s)
}
return cs, nil
}
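// filterClusters matches the assigned clusters against the known deploy
// instances and returns, per cluster, the inference URLs that can serve
// requests. Clusters without a reachable instance are dropped.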
func (i *ImageInference) filterClusters() ([]*FilteredCluster, error) {
var cs []*FilteredCluster
for _, cluster := range i.clusters {
var inferurls []*inference.InferUrl
var clustertype string
var clusterName string
for _, instance := range i.instances {
clusterId := strconv.FormatInt(instance.ClusterId, 10)
adapterId := strconv.FormatInt(instance.AdapterId, 10)
if cluster.ClusterId == clusterId {
				deployInstance, err := i.inferAdapter[adapterId][clusterId].GetInferDeployInstance(context.Background(), instance.InstanceId)
				if err != nil {
					logx.Errorf("get infer deploy instance on cluster %s failed: %v", clusterId, err)
					continue
				}
var url inference.InferUrl
url.Url = deployInstance.InferUrl
url.Card = deployInstance.InferCard
inferurls = append(inferurls, &url)
clustertype = deployInstance.ClusterType
clusterName = deployInstance.ClusterName
}
}
if len(inferurls) == 0 {
continue
}
		if err := i.inference.AppendRoute(inferurls); err != nil {
			logx.Errorf("append inference routes for cluster %s failed: %v", cluster.ClusterId, err)
		}
var f FilteredCluster
f.urls = inferurls
f.clusterId = cluster.ClusterId
f.clusterName = clusterName
f.clusterType = clustertype
f.imageNum = cluster.Replicas
cs = append(cs, &f)
}
return cs, nil
}
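// inferImages splits the uploaded images into per-cluster batches according
// to each cluster's replica count, sends them concurrently and returns the
// results sorted by cluster name.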
func (i *ImageInference) inferImages(cs []*FilteredCluster) ([]*types.ImageResult, error) {
	var wg sync.WaitGroup
	var ch = make(chan *types.ImageResult, len(i.files))
	var results []*types.ImageResult
	// limit caps the number of in-flight inference requests across all clusters.
	limit := make(chan bool, 5)
	var imageNumIdx int32 = 0
	var imageNumIdxEnd int32 = 0
	for _, c := range cs {
		// Each cluster handles a contiguous batch of images sized by its replica
		// count; the replica counts are expected to sum to len(i.files).
		imageNumIdxEnd = imageNumIdxEnd + c.imageNum
		batch := i.files[imageNumIdx:imageNumIdxEnd]
		imageNumIdx = imageNumIdx + c.imageNum
		wg.Add(len(batch))
		go i.sendInferReq(batch, c, &wg, ch, limit)
	}
wg.Wait()
close(ch)
for s := range ch {
results = append(results, s)
}
sort.Slice(results, func(p, q int) bool {
return results[p].ClusterName < results[q].ClusterName
})
return results, nil
}
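// updateStatus reconciles the stored AI task rows with the filtered cluster
// list: tasks on unavailable clusters are marked Failed (with the recorded
// error message), the rest are marked Running, and a notice is emitted.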
func (i *ImageInference) updateStatus(aiTaskList []*models.TaskAi, cs []*FilteredCluster) error {
//no cluster available
if len(cs) == 0 {
for _, t := range aiTaskList {
t.Status = constants.Failed
t.EndTime = time.Now().Format(time.RFC3339)
if _, ok := i.errMap[strconv.Itoa(int(t.ClusterId))]; ok {
t.Msg = i.errMap[strconv.Itoa(int(t.ClusterId))]
}
err := i.storage.UpdateAiTask(t)
if err != nil {
logx.Errorf(err.Error())
}
}
		i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "failed", "任务失败") // "task failed"
		return errors.New("no available clusters, image infer task failed")
}
//change cluster status
if len(i.clusters) != len(cs) {
var failedclusters []*strategy.AssignedCluster
var runningclusters []*strategy.AssignedCluster
		for _, cluster := range i.clusters {
			if contains(cs, cluster.ClusterId) {
				runningclusters = append(runningclusters, cluster)
			} else {
				failedclusters = append(failedclusters, cluster)
			}
		}
// update failed cluster status
for _, ac := range failedclusters {
for _, t := range aiTaskList {
if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) {
t.Status = constants.Failed
t.EndTime = time.Now().Format(time.RFC3339)
if _, ok := i.errMap[strconv.Itoa(int(t.ClusterId))]; ok {
t.Msg = i.errMap[strconv.Itoa(int(t.ClusterId))]
}
err := i.storage.UpdateAiTask(t)
if err != nil {
logx.Errorf(err.Error())
}
}
}
}
// update running cluster status
for _, ac := range runningclusters {
for _, t := range aiTaskList {
if ac.ClusterId == strconv.Itoa(int(t.ClusterId)) {
t.Status = constants.Running
err := i.storage.UpdateAiTask(t)
if err != nil {
logx.Errorf(err.Error())
}
}
}
}
		i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "failed", "任务失败") // "task failed": some assigned clusters are unavailable
} else {
for _, t := range aiTaskList {
t.Status = constants.Running
err := i.storage.UpdateAiTask(t)
if err != nil {
logx.Errorf(err.Error())
}
}
		i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "running", "任务运行中") // "task is running"
}
return nil
}
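// sendInferReq sends one inference request per image to the given cluster,
// choosing a route at random when several are available, and pushes each
// ImageResult onto ch. Concurrency is bounded by the limit channel.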
func (i *ImageInference) sendInferReq(images []*ImageFile, cluster *FilteredCluster, wg *sync.WaitGroup, ch chan<- *types.ImageResult, limit chan bool) {
	for _, image := range images {
		limit <- true
		go func(t *ImageFile, c *FilteredCluster) {
			defer wg.Done()
			defer func() { <-limit }()
			// Use the only route when one is available, otherwise pick one at random.
			idx := 0
			if len(c.urls) > 1 {
				idx = rand.Intn(len(c.urls))
			}
			r, err := getInferResult(c.urls[idx].Url, t.File, t.ImageResult.ImageName, c.clusterId, c.clusterType, i.inferAdapter, i.opt.AdapterId)
			if err != nil {
				// Keep the original behavior: record the error text as the result.
				r = err.Error()
			}
			t.ImageResult.ImageResult = r
			t.ImageResult.ClusterId = c.clusterId
			t.ImageResult.ClusterName = c.clusterName
			t.ImageResult.Card = c.urls[idx].Card
			ch <- t.ImageResult
		}(image, cluster)
	}
}
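// saveAiSubTasks persists one sub task row per image result, then marks the
// clusters that produced results as Completed and emits a final notice.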
func (i *ImageInference) saveAiSubTasks(id int64, aiTaskList []*models.TaskAi, cs []*FilteredCluster, results []*types.ImageResult) error {
//save ai sub tasks
for _, r := range results {
for _, task := range aiTaskList {
if r.ClusterId == strconv.Itoa(int(task.ClusterId)) {
taskAiSub := models.TaskAiSub{
TaskId: id,
TaskName: task.Name,
TaskAiId: task.TaskId,
TaskAiName: task.Name,
ImageName: r.ImageName,
Result: r.ImageResult,
Card: r.Card,
ClusterId: task.ClusterId,
ClusterName: r.ClusterName,
}
				err := i.storage.SaveAiTaskImageSubTask(&taskAiSub)
				if err != nil {
					return err
				}
}
}
}
// update succeeded cluster status
var successStatusCount int
for _, c := range cs {
for _, t := range aiTaskList {
if c.clusterId == strconv.Itoa(int(t.ClusterId)) {
t.Status = constants.Completed
t.EndTime = time.Now().Format(time.RFC3339)
err := i.storage.UpdateAiTask(t)
if err != nil {
logx.Errorf(err.Error())
}
successStatusCount++
			}
}
}
	if len(cs) == successStatusCount {
		i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "completed", "任务完成") // "task completed"
	} else {
		i.storage.AddNoticeInfo(i.opt.AdapterId, i.adapterName, "", "", i.opt.TaskName, "failed", "任务失败") // "task failed"
	}
return nil
}
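// getInferResult forwards a single image to the cluster's inference URL,
// dispatching on cluster type: Octopus and ModelArts have dedicated paths,
// everything else goes through a plain multipart POST.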
func getInferResult(url string, file multipart.File, fileName string, clusterId string, clusterType string, inferAdapter map[string]map[string]inference.ICluster, adapterId string) (string, error) {
adapter, found := inferAdapter[adapterId]
if !found {
return "", errors.New("adapterId not found")
}
iCluster, found := adapter[clusterId]
if !found {
return "", errors.New("clusterId not found")
}
switch clusterType {
case storeLink.TYPE_OCTOPUS:
		result, err := iCluster.GetInferResult(context.Background(), url, file, fileName)
if err != nil {
return "", err
}
return result, nil
case storeLink.TYPE_MODELARTS:
r, err := getInferResultModelarts(url, file, fileName)
if err != nil {
return "", err
}
return r, nil
default:
var res Res
req := GetRestyRequest(20)
_, err := req.
SetFileReader("file", fileName, file).
SetResult(&res).
Post(url)
if err != nil {
return "", err
}
return res.Result, nil
}
}
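// getInferResultModelarts posts the image to the ModelArts inference endpoint
// via utils.SendRequest and decodes the JSON response into Res.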
func getInferResultModelarts(url string, file multipart.File, fileName string) (string, error) {
var res Res
body, err := utils.SendRequest("POST", url, file, fileName)
if err != nil {
return "", err
}
	if err := json.Unmarshal([]byte(body), &res); err != nil {
		return "", err
	}
return res.Result, nil
}
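// GetRestyRequest returns a resty request backed by a client with the given
// timeout in seconds.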
func GetRestyRequest(timeoutSeconds int64) *resty.Request {
client := resty.New().SetTimeout(time.Duration(timeoutSeconds) * time.Second)
request := client.R()
return request
}
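// Res is the JSON envelope returned by the inference endpoints.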
type Res struct {
Result string `json:"result"`
}
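// contains reports whether the filtered cluster list includes the given cluster id.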
func contains(cs []*FilteredCluster, e string) bool {
for _, c := range cs {
if c.clusterId == e {
return true
}
}
return false
}