diff --git a/api/internal/logic/inference/imageinferencelogic.go b/api/internal/logic/inference/imageinferencelogic.go index 30567d8b..0eac524d 100644 --- a/api/internal/logic/inference/imageinferencelogic.go +++ b/api/internal/logic/inference/imageinferencelogic.go @@ -1,8 +1,13 @@ package inference +import "C" import ( + "APIGW-go-sdk/core" + "bytes" "context" + "crypto/tls" "errors" + "fmt" "github.com/go-resty/resty/v2" "github.com/zeromicro/go-zero/core/logx" "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/schedulers/option" @@ -12,6 +17,9 @@ import ( "gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types" "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants" "gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models" + "io" + "k8s.io/apimachinery/pkg/util/json" + "log" "math/rand" "mime/multipart" "net/http" @@ -334,7 +342,7 @@ func sendInferReq(images []struct { imageNum int32 }) { if len(c.urls) == 1 { - r, err := getInferResult(c.urls[0].Url, t.file, t.imageResult.ImageName) + r, err := getInferResult(c.urls[0].Url, t.file, t.imageResult.ImageName, c.clusterName) if err != nil { t.imageResult.ImageResult = err.Error() t.imageResult.ClusterName = c.clusterName @@ -352,7 +360,7 @@ func sendInferReq(images []struct { return } else { idx := rand.Intn(len(c.urls)) - r, err := getInferResult(c.urls[idx].Url, t.file, t.imageResult.ImageName) + r, err := getInferResult(c.urls[idx].Url, t.file, t.imageResult.ImageName, c.clusterName) if err != nil { t.imageResult.ImageResult = err.Error() t.imageResult.ClusterName = c.clusterName @@ -373,20 +381,113 @@ func sendInferReq(images []struct { } } -func getInferResult(url string, file multipart.File, fileName string) (string, error) { +func getInferResult(url string, file multipart.File, fileName string, clusterName string) (string, error) { + if clusterName == "鹏城云脑II-modelarts" { + r, err := getInferResultModelarts(url, file, fileName) + if err != nil { + return "", err + } + return r, nil + } var res Res req := GetRestyRequest(10) _, err := req. SetFileReader("file", fileName, file). SetResult(&res). Post(url) - if err != nil { return "", err } return res.Result, nil } +func getInferResultModelarts(url string, file multipart.File, fileName string) (string, error) { + var res Res + body, err := SendRequest("POST", url, file, fileName) + if err != nil { + return "", err + } + errjson := json.Unmarshal([]byte(body), &res) + if errjson != nil { + log.Fatalf("Error parsing JSON: %s", errjson) + } + + return res.Result, nil +} + +// SignClient AK/SK签名认证 +func SignClient(r *http.Request, writer *multipart.Writer) (*http.Client, error) { + r.Header.Add("content-type", "application/json;charset=UTF-8") + r.Header.Add("X-Project-Id", "d18190e28e3f45a281ef0b0696ec9d52") + r.Header.Add("x-stage", "RELEASE") + r.Header.Add("x-sdk-content-sha256", "UNSIGNED-PAYLOAD") + r.Header.Set("Content-Type", writer.FormDataContentType()) + s := core.Signer{ + Key: "UNEHPHO4Z7YSNPKRXFE4", + Secret: "JWXCE9qcYbc7RjpSRIWt4WgG3ZKF6Q4lPzkJReX9", + } + err := s.Sign(r) + if err != nil { + return nil, err + } + + //设置client信任所有证书 + tr := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + client := &http.Client{ + Transport: tr, + } + return client, nil +} + +func SendRequest(method, url string, file multipart.File, fileName string) (string, error) { + /*body := &bytes.Buffer{} + writer := multipart.NewWriter(body)*/ + // 创建一个新的缓冲区以写入multipart表单 + var body bytes.Buffer + // 创建一个新的multipart writer + writer := multipart.NewWriter(&body) + // 创建一个用于写入文件的表单字段 + part, err := writer.CreateFormFile("file", fileName) // "file"是表单的字段名,第二个参数是文件名 + if err != nil { + fmt.Println("Error creating form file:", err) + } + // 将文件的内容拷贝到multipart writer中 + _, err = io.Copy(part, file) + if err != nil { + fmt.Println("Error copying file data:", err) + + } + err = writer.Close() + if err != nil { + fmt.Println("Error closing multipart writer:", err) + } + request, err := http.NewRequest(method, "https://modelarts-inference.cloudbrain2.pcl.ac.cn/v1/infers/fb0f011f-3e74-4396-ab81-20d65525d22b/image", &body) + if err != nil { + fmt.Println("Error creating new request:", err) + //return nil, err + } + signedR, err := SignClient(request, writer) + if err != nil { + fmt.Println("Error signing request:", err) + //return nil, err + } + + res, err := signedR.Do(request) + if err != nil { + fmt.Println("Error sending request:", err) + //return nil, err + } + defer res.Body.Close() + Resbody, err := io.ReadAll(res.Body) + if err != nil { + fmt.Println("Error reading response body:", err) + //return nil, err + } + return string(Resbody), nil +} + func GetRestyRequest(timeoutSeconds int64) *resty.Request { client := resty.New().SetTimeout(time.Duration(timeoutSeconds) * time.Second) request := client.R()