Former-commit-id: 352ba59abdbef6f56251beeac4d0e991a648cc82
This commit is contained in:
zhangwei 2024-06-22 16:55:22 +08:00
commit 19ca559051
10 changed files with 517 additions and 4 deletions

View File

@ -44,4 +44,15 @@ type (
ImageResult string `json:"imageResult"`
}
InferenceTaskDetailReq{
aiTaskId int64 `json:"aiTaskId"`
}
InferenceTaskDetailResp{
imageName string `json:"imageName"`
result string `json:"result"`
card string `json:"card"`
clusterName string `json:"clusterName"`
}
)

View File

@ -1,8 +1,13 @@
package inference
import "C"
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"github.com/JCCE-nudt/apigw-go-sdk/core"
"github.com/go-resty/resty/v2"
"github.com/zeromicro/go-zero/core/logx"
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/schedulers/option"
@ -12,6 +17,9 @@ import (
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
"io"
"k8s.io/apimachinery/pkg/util/json"
"log"
"math/rand"
"mime/multipart"
"net/http"
@ -284,6 +292,26 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s
results = append(results, s)
}
//save ai sub tasks
for _, r := range results {
for _, task := range aiTaskList {
if r.ClusterId == strconv.Itoa(int(task.ClusterId)) {
taskAiSub := &models.TaskAiSub{
Id: task.Id,
ImageName: r.ImageName,
Result: r.ImageResult,
Card: r.Card,
ClusterId: task.ClusterId,
ClusterName: r.ClusterName,
}
tx := svcCtx.DbEngin.Save(&taskAiSub)
if tx.Error != nil {
logx.Errorf(err.Error())
}
}
}
}
sort.Slice(results, func(p, q int) bool {
return results[p].ClusterName < results[q].ClusterName
})
@ -334,9 +362,10 @@ func sendInferReq(images []struct {
imageNum int32
}) {
if len(c.urls) == 1 {
r, err := getInferResult(c.urls[0].Url, t.file, t.imageResult.ImageName)
r, err := getInferResult(c.urls[0].Url, t.file, t.imageResult.ImageName, c.clusterName)
if err != nil {
t.imageResult.ImageResult = err.Error()
t.imageResult.ClusterId = c.clusterId
t.imageResult.ClusterName = c.clusterName
t.imageResult.Card = c.urls[0].Card
ch <- t.imageResult
@ -344,6 +373,7 @@ func sendInferReq(images []struct {
return
}
t.imageResult.ImageResult = r
t.imageResult.ClusterId = c.clusterId
t.imageResult.ClusterName = c.clusterName
t.imageResult.Card = c.urls[0].Card
@ -352,9 +382,10 @@ func sendInferReq(images []struct {
return
} else {
idx := rand.Intn(len(c.urls))
r, err := getInferResult(c.urls[idx].Url, t.file, t.imageResult.ImageName)
r, err := getInferResult(c.urls[idx].Url, t.file, t.imageResult.ImageName, c.clusterName)
if err != nil {
t.imageResult.ImageResult = err.Error()
t.imageResult.ClusterId = c.clusterId
t.imageResult.ClusterName = c.clusterName
t.imageResult.Card = c.urls[idx].Card
ch <- t.imageResult
@ -362,6 +393,7 @@ func sendInferReq(images []struct {
return
}
t.imageResult.ImageResult = r
t.imageResult.ClusterId = c.clusterId
t.imageResult.ClusterName = c.clusterName
t.imageResult.Card = c.urls[idx].Card
@ -373,20 +405,113 @@ func sendInferReq(images []struct {
}
}
func getInferResult(url string, file multipart.File, fileName string) (string, error) {
func getInferResult(url string, file multipart.File, fileName string, clusterName string) (string, error) {
if clusterName == "鹏城云脑II-modelarts" {
r, err := getInferResultModelarts(url, file, fileName)
if err != nil {
return "", err
}
return r, nil
}
var res Res
req := GetRestyRequest(10)
_, err := req.
SetFileReader("file", fileName, file).
SetResult(&res).
Post(url)
if err != nil {
return "", err
}
return res.Result, nil
}
func getInferResultModelarts(url string, file multipart.File, fileName string) (string, error) {
var res Res
body, err := SendRequest("POST", url, file, fileName)
if err != nil {
return "", err
}
errjson := json.Unmarshal([]byte(body), &res)
if errjson != nil {
log.Fatalf("Error parsing JSON: %s", errjson)
}
return res.Result, nil
}
// SignClient AK/SK签名认证
func SignClient(r *http.Request, writer *multipart.Writer) (*http.Client, error) {
r.Header.Add("content-type", "application/json;charset=UTF-8")
r.Header.Add("X-Project-Id", "d18190e28e3f45a281ef0b0696ec9d52")
r.Header.Add("x-stage", "RELEASE")
r.Header.Add("x-sdk-content-sha256", "UNSIGNED-PAYLOAD")
r.Header.Set("Content-Type", writer.FormDataContentType())
s := core.Signer{
Key: "UNEHPHO4Z7YSNPKRXFE4",
Secret: "JWXCE9qcYbc7RjpSRIWt4WgG3ZKF6Q4lPzkJReX9",
}
err := s.Sign(r)
if err != nil {
return nil, err
}
//设置client信任所有证书
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
client := &http.Client{
Transport: tr,
}
return client, nil
}
func SendRequest(method, url string, file multipart.File, fileName string) (string, error) {
/*body := &bytes.Buffer{}
writer := multipart.NewWriter(body)*/
// 创建一个新的缓冲区以写入multipart表单
var body bytes.Buffer
// 创建一个新的multipart writer
writer := multipart.NewWriter(&body)
// 创建一个用于写入文件的表单字段
part, err := writer.CreateFormFile("file", fileName) // "file"是表单的字段名,第二个参数是文件名
if err != nil {
fmt.Println("Error creating form file:", err)
}
// 将文件的内容拷贝到multipart writer中
_, err = io.Copy(part, file)
if err != nil {
fmt.Println("Error copying file data:", err)
}
err = writer.Close()
if err != nil {
fmt.Println("Error closing multipart writer:", err)
}
request, err := http.NewRequest(method, url, &body)
if err != nil {
fmt.Println("Error creating new request:", err)
//return nil, err
}
signedR, err := SignClient(request, writer)
if err != nil {
fmt.Println("Error signing request:", err)
//return nil, err
}
res, err := signedR.Do(request)
if err != nil {
fmt.Println("Error sending request:", err)
//return nil, err
}
defer res.Body.Close()
Resbody, err := io.ReadAll(res.Body)
if err != nil {
fmt.Println("Error reading response body:", err)
//return nil, err
}
return string(Resbody), nil
}
func GetRestyRequest(timeoutSeconds int64) *resty.Request {
client := resty.New().SetTimeout(time.Duration(timeoutSeconds) * time.Second)
request := client.R()

View File

@ -5910,3 +5910,14 @@ type ImageResult struct {
Card string `json:"card"`
ImageResult string `json:"imageResult"`
}
type InferenceTaskDetailReq struct {
AiTaskId int64 `json:"aiTaskId"`
}
type InferenceTaskDetailResp struct {
ImageName string `json:"imageName"`
Result string `json:"result"`
Card string `json:"card"`
ClusterName string `json:"clusterName"`
}

1
go.mod
View File

@ -7,6 +7,7 @@ toolchain go1.22.4
retract v0.1.20-0.20240319015239-6ae13da05255
require (
github.com/JCCE-nudt/apigw-go-sdk v0.0.0-20230525025609-34159d6f2818
github.com/Masterminds/squirrel v1.5.4
github.com/bwmarrin/snowflake v0.3.0
github.com/ghodss/yaml v1.0.0

2
go.sum
View File

@ -38,6 +38,8 @@ github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ=
github.com/JCCE-nudt/apigw-go-sdk v0.0.0-20230525025609-34159d6f2818 h1:QLulhUyxPDs9FFieVZwmKAnUBLeRDhsVNehotAAL/FE=
github.com/JCCE-nudt/apigw-go-sdk v0.0.0-20230525025609-34159d6f2818/go.mod h1:j+am5/1URgsvyhOAyURFR9vH3malaW7Tq6d33OyPsnM=
github.com/Masterminds/squirrel v1.5.4 h1:uUcX/aBc8O7Fg9kaISIUsHXdKuqehiXAMQTYX8afzqM=
github.com/Masterminds/squirrel v1.5.4/go.mod h1:NNaOrjSoIDfDA40n7sr2tPNZRfjzjA400rg+riTZj10=
github.com/alecthomas/kingpin/v2 v2.4.0 h1:f48lwail6p8zpO1bC4TxtqACaGqHYA22qkHjHpqDjYY=

View File

@ -0,0 +1,29 @@
package models
import "github.com/zeromicro/go-zero/core/stores/sqlx"
var _ TaskAiSubModel = (*customTaskAiSubModel)(nil)
type (
// TaskAiSubModel is an interface to be customized, add more methods here,
// and implement the added methods in customTaskAiSubModel.
TaskAiSubModel interface {
taskAiSubModel
withSession(session sqlx.Session) TaskAiSubModel
}
customTaskAiSubModel struct {
*defaultTaskAiSubModel
}
)
// NewTaskAiSubModel returns a model for the database table.
func NewTaskAiSubModel(conn sqlx.SqlConn) TaskAiSubModel {
return &customTaskAiSubModel{
defaultTaskAiSubModel: newTaskAiSubModel(conn),
}
}
func (m *customTaskAiSubModel) withSession(session sqlx.Session) TaskAiSubModel {
return NewTaskAiSubModel(sqlx.NewSqlConnFromSession(session))
}

View File

@ -0,0 +1,88 @@
// Code generated by goctl. DO NOT EDIT.
package models
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/zeromicro/go-zero/core/stores/builder"
"github.com/zeromicro/go-zero/core/stores/sqlc"
"github.com/zeromicro/go-zero/core/stores/sqlx"
"github.com/zeromicro/go-zero/core/stringx"
)
var (
taskAiSubFieldNames = builder.RawFieldNames(&TaskAiSub{})
taskAiSubRows = strings.Join(taskAiSubFieldNames, ",")
taskAiSubRowsExpectAutoSet = strings.Join(stringx.Remove(taskAiSubFieldNames, "`id`", "`create_at`", "`create_time`", "`created_at`", "`update_at`", "`update_time`", "`updated_at`"), ",")
taskAiSubRowsWithPlaceHolder = strings.Join(stringx.Remove(taskAiSubFieldNames, "`id`", "`create_at`", "`create_time`", "`created_at`", "`update_at`", "`update_time`", "`updated_at`"), "=?,") + "=?"
)
type (
taskAiSubModel interface {
Insert(ctx context.Context, data *TaskAiSub) (sql.Result, error)
FindOne(ctx context.Context, id int64) (*TaskAiSub, error)
Update(ctx context.Context, data *TaskAiSub) error
Delete(ctx context.Context, id int64) error
}
defaultTaskAiSubModel struct {
conn sqlx.SqlConn
table string
}
TaskAiSub struct {
Id int64 `db:"id"` // id
ImageName string `db:"image_name"` // 图片名称
Result string `db:"result"` // 识别结果
Card string `db:"card"` // 加速卡
ClusterId int64 `db:"cluster_id"` // 集群id
ClusterName string `db:"cluster_name"` // 集群名称
}
)
func newTaskAiSubModel(conn sqlx.SqlConn) *defaultTaskAiSubModel {
return &defaultTaskAiSubModel{
conn: conn,
table: "`task_ai_sub`",
}
}
func (m *defaultTaskAiSubModel) Delete(ctx context.Context, id int64) error {
query := fmt.Sprintf("delete from %s where `id` = ?", m.table)
_, err := m.conn.ExecCtx(ctx, query, id)
return err
}
func (m *defaultTaskAiSubModel) FindOne(ctx context.Context, id int64) (*TaskAiSub, error) {
query := fmt.Sprintf("select %s from %s where `id` = ? limit 1", taskAiSubRows, m.table)
var resp TaskAiSub
err := m.conn.QueryRowCtx(ctx, &resp, query, id)
switch err {
case nil:
return &resp, nil
case sqlc.ErrNotFound:
return nil, ErrNotFound
default:
return nil, err
}
}
func (m *defaultTaskAiSubModel) Insert(ctx context.Context, data *TaskAiSub) (sql.Result, error) {
query := fmt.Sprintf("insert into %s (%s) values (?, ?, ?, ?, ?)", m.table, taskAiSubRowsExpectAutoSet)
ret, err := m.conn.ExecCtx(ctx, query, data.ImageName, data.Result, data.Card, data.ClusterId, data.ClusterName)
return ret, err
}
func (m *defaultTaskAiSubModel) Update(ctx context.Context, data *TaskAiSub) error {
query := fmt.Sprintf("update %s set %s where `id` = ?", m.table, taskAiSubRowsWithPlaceHolder)
_, err := m.conn.ExecCtx(ctx, query, data.ImageName, data.Result, data.Card, data.ClusterId, data.ClusterName, data.Id)
return err
}
func (m *defaultTaskAiSubModel) tableName() string {
return m.table
}

View File

@ -0,0 +1,42 @@
// based on https://github.com/golang/go/blob/master/src/net/url/url.go
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package core
func shouldEscape(c byte) bool {
if 'A' <= c && c <= 'Z' || 'a' <= c && c <= 'z' || '0' <= c && c <= '9' || c == '_' || c == '-' || c == '~' || c == '.' {
return false
}
return true
}
func escape(s string) string {
hexCount := 0
for i := 0; i < len(s); i++ {
c := s[i]
if shouldEscape(c) {
hexCount++
}
}
if hexCount == 0 {
return s
}
t := make([]byte, len(s)+2*hexCount)
j := 0
for i := 0; i < len(s); i++ {
switch c := s[i]; {
case shouldEscape(c):
t[j] = '%'
t[j+1] = "0123456789ABCDEF"[c>>4]
t[j+2] = "0123456789ABCDEF"[c&15]
j += 3
default:
t[j] = s[i]
j++
}
}
return string(t)
}

201
vendor/github.com/JCCE-nudt/apigw-go-sdk/core/signer.go generated vendored Normal file
View File

@ -0,0 +1,201 @@
// HWS API Gateway Signature
// based on https://github.com/datastream/aws/blob/master/signv4.go
// Copyright (c) 2014, Xianjie
package core
import (
"bytes"
"crypto/hmac"
"crypto/sha256"
"fmt"
"io/ioutil"
"net/http"
"sort"
"strings"
"time"
)
const (
DateFormat = "20060102T150405Z"
SignAlgorithm = "SDK-HMAC-SHA256"
HeaderXDateTime = "X-Sdk-Date"
HeaderXHost = "host"
HeaderXAuthorization = "Authorization"
HeaderXContentSha256 = "X-Sdk-Content-Sha256"
)
func hmacsha256(keyByte []byte, dataStr string) ([]byte, error) {
hm := hmac.New(sha256.New, []byte(keyByte))
if _, err := hm.Write([]byte(dataStr)); err != nil {
return nil, err
}
return hm.Sum(nil), nil
}
// Build a CanonicalRequest from a regular request string
func CanonicalRequest(request *http.Request, signedHeaders []string) (string, error) {
var hexencode string
var err error
if hex := request.Header.Get(HeaderXContentSha256); hex != "" {
hexencode = hex
} else {
bodyData, err := RequestPayload(request)
if err != nil {
return "", err
}
hexencode, err = HexEncodeSHA256Hash(bodyData)
if err != nil {
return "", err
}
}
return fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", request.Method, CanonicalURI(request), CanonicalQueryString(request), CanonicalHeaders(request, signedHeaders), strings.Join(signedHeaders, ";"), hexencode), err
}
// CanonicalURI returns request uri
func CanonicalURI(request *http.Request) string {
pattens := strings.Split(request.URL.Path, "/")
var uriSlice []string
for _, v := range pattens {
uriSlice = append(uriSlice, escape(v))
}
urlpath := strings.Join(uriSlice, "/")
if len(urlpath) == 0 || urlpath[len(urlpath)-1] != '/' {
urlpath = urlpath + "/"
}
return urlpath
}
// CanonicalQueryString
func CanonicalQueryString(request *http.Request) string {
var keys []string
queryMap := request.URL.Query()
for key := range queryMap {
keys = append(keys, key)
}
sort.Strings(keys)
var query []string
for _, key := range keys {
k := escape(key)
sort.Strings(queryMap[key])
for _, v := range queryMap[key] {
kv := fmt.Sprintf("%s=%s", k, escape(v))
query = append(query, kv)
}
}
queryStr := strings.Join(query, "&")
request.URL.RawQuery = queryStr
return queryStr
}
// CanonicalHeaders
func CanonicalHeaders(request *http.Request, signerHeaders []string) string {
var canonicalHeaders []string
header := make(map[string][]string)
for k, v := range request.Header {
header[strings.ToLower(k)] = v
}
for _, key := range signerHeaders {
value := header[key]
if strings.EqualFold(key, HeaderXHost) {
value = []string{request.Host}
}
sort.Strings(value)
for _, v := range value {
canonicalHeaders = append(canonicalHeaders, key+":"+strings.TrimSpace(v))
}
}
return fmt.Sprintf("%s\n", strings.Join(canonicalHeaders, "\n"))
}
// SignedHeaders
func SignedHeaders(r *http.Request) []string {
var signedHeaders []string
for key := range r.Header {
signedHeaders = append(signedHeaders, strings.ToLower(key))
}
sort.Strings(signedHeaders)
return signedHeaders
}
// RequestPayload
func RequestPayload(request *http.Request) ([]byte, error) {
if request.Body == nil {
return []byte(""), nil
}
bodyByte, err := ioutil.ReadAll(request.Body)
if err != nil {
return []byte(""), err
}
request.Body = ioutil.NopCloser(bytes.NewBuffer(bodyByte))
return bodyByte, err
}
// Create a "String to Sign".
func StringToSign(canonicalRequest string, t time.Time) (string, error) {
hashStruct := sha256.New()
_, err := hashStruct.Write([]byte(canonicalRequest))
if err != nil {
return "", err
}
return fmt.Sprintf("%s\n%s\n%x",
SignAlgorithm, t.UTC().Format(DateFormat), hashStruct.Sum(nil)), nil
}
// Create the HWS Signature.
func SignStringToSign(stringToSign string, signingKey []byte) (string, error) {
hmsha, err := hmacsha256(signingKey, stringToSign)
return fmt.Sprintf("%x", hmsha), err
}
// HexEncodeSHA256Hash returns hexcode of sha256
func HexEncodeSHA256Hash(body []byte) (string, error) {
hashStruct := sha256.New()
if len(body) == 0 {
body = []byte("")
}
_, err := hashStruct.Write(body)
return fmt.Sprintf("%x", hashStruct.Sum(nil)), err
}
// Get the finalized value for the "Authorization" header. The signature parameter is the output from SignStringToSign
func AuthHeaderValue(signatureStr, accessKeyStr string, signedHeaders []string) string {
return fmt.Sprintf("%s Access=%s, SignedHeaders=%s, Signature=%s", SignAlgorithm, accessKeyStr, strings.Join(signedHeaders, ";"), signatureStr)
}
// Signature HWS meta
type Signer struct {
Key string
Secret string
}
// SignRequest set Authorization header
func (s *Signer) Sign(request *http.Request) error {
var t time.Time
var err error
var date string
if date = request.Header.Get(HeaderXDateTime); date != "" {
t, err = time.Parse(DateFormat, date)
}
if err != nil || date == "" {
t = time.Now()
request.Header.Set(HeaderXDateTime, t.UTC().Format(DateFormat))
}
signedHeaders := SignedHeaders(request)
canonicalRequest, err := CanonicalRequest(request, signedHeaders)
if err != nil {
return err
}
stringToSignStr, err := StringToSign(canonicalRequest, t)
if err != nil {
return err
}
signatureStr, err := SignStringToSign(stringToSignStr, []byte(s.Secret))
if err != nil {
return err
}
authValueStr := AuthHeaderValue(signatureStr, s.Key, signedHeaders)
request.Header.Set(HeaderXAuthorization, authValueStr)
return nil
}

3
vendor/modules.txt vendored
View File

@ -2,6 +2,9 @@
## explicit; go 1.20
filippo.io/edwards25519
filippo.io/edwards25519/field
# github.com/JCCE-nudt/apigw-go-sdk v0.0.0-20230525025609-34159d6f2818
## explicit; go 1.17
github.com/JCCE-nudt/apigw-go-sdk/core
# github.com/Masterminds/squirrel v1.5.4
## explicit; go 1.14
github.com/Masterminds/squirrel