Merge branch 'master' of https://gitlink.org.cn/JointCloud/pcm-coordinator
Former-commit-id: 352ba59abdbef6f56251beeac4d0e991a648cc82
This commit is contained in:
commit
19ca559051
|
@ -44,4 +44,15 @@ type (
|
||||||
ImageResult string `json:"imageResult"`
|
ImageResult string `json:"imageResult"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
InferenceTaskDetailReq{
|
||||||
|
aiTaskId int64 `json:"aiTaskId"`
|
||||||
|
}
|
||||||
|
|
||||||
|
InferenceTaskDetailResp{
|
||||||
|
imageName string `json:"imageName"`
|
||||||
|
result string `json:"result"`
|
||||||
|
card string `json:"card"`
|
||||||
|
clusterName string `json:"clusterName"`
|
||||||
|
}
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,8 +1,13 @@
|
||||||
package inference
|
package inference
|
||||||
|
|
||||||
|
import "C"
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"github.com/JCCE-nudt/apigw-go-sdk/core"
|
||||||
"github.com/go-resty/resty/v2"
|
"github.com/go-resty/resty/v2"
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/schedulers/option"
|
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/scheduler/schedulers/option"
|
||||||
|
@ -12,6 +17,9 @@ import (
|
||||||
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types"
|
"gitlink.org.cn/JointCloud/pcm-coordinator/api/internal/types"
|
||||||
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
|
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/constants"
|
||||||
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
|
"gitlink.org.cn/JointCloud/pcm-coordinator/pkg/models"
|
||||||
|
"io"
|
||||||
|
"k8s.io/apimachinery/pkg/util/json"
|
||||||
|
"log"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -284,6 +292,26 @@ func infer(opt *option.InferOption, clusters []*strategy.AssignedCluster, ts []s
|
||||||
results = append(results, s)
|
results = append(results, s)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//save ai sub tasks
|
||||||
|
for _, r := range results {
|
||||||
|
for _, task := range aiTaskList {
|
||||||
|
if r.ClusterId == strconv.Itoa(int(task.ClusterId)) {
|
||||||
|
taskAiSub := &models.TaskAiSub{
|
||||||
|
Id: task.Id,
|
||||||
|
ImageName: r.ImageName,
|
||||||
|
Result: r.ImageResult,
|
||||||
|
Card: r.Card,
|
||||||
|
ClusterId: task.ClusterId,
|
||||||
|
ClusterName: r.ClusterName,
|
||||||
|
}
|
||||||
|
tx := svcCtx.DbEngin.Save(&taskAiSub)
|
||||||
|
if tx.Error != nil {
|
||||||
|
logx.Errorf(err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
sort.Slice(results, func(p, q int) bool {
|
sort.Slice(results, func(p, q int) bool {
|
||||||
return results[p].ClusterName < results[q].ClusterName
|
return results[p].ClusterName < results[q].ClusterName
|
||||||
})
|
})
|
||||||
|
@ -334,9 +362,10 @@ func sendInferReq(images []struct {
|
||||||
imageNum int32
|
imageNum int32
|
||||||
}) {
|
}) {
|
||||||
if len(c.urls) == 1 {
|
if len(c.urls) == 1 {
|
||||||
r, err := getInferResult(c.urls[0].Url, t.file, t.imageResult.ImageName)
|
r, err := getInferResult(c.urls[0].Url, t.file, t.imageResult.ImageName, c.clusterName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.imageResult.ImageResult = err.Error()
|
t.imageResult.ImageResult = err.Error()
|
||||||
|
t.imageResult.ClusterId = c.clusterId
|
||||||
t.imageResult.ClusterName = c.clusterName
|
t.imageResult.ClusterName = c.clusterName
|
||||||
t.imageResult.Card = c.urls[0].Card
|
t.imageResult.Card = c.urls[0].Card
|
||||||
ch <- t.imageResult
|
ch <- t.imageResult
|
||||||
|
@ -344,6 +373,7 @@ func sendInferReq(images []struct {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
t.imageResult.ImageResult = r
|
t.imageResult.ImageResult = r
|
||||||
|
t.imageResult.ClusterId = c.clusterId
|
||||||
t.imageResult.ClusterName = c.clusterName
|
t.imageResult.ClusterName = c.clusterName
|
||||||
t.imageResult.Card = c.urls[0].Card
|
t.imageResult.Card = c.urls[0].Card
|
||||||
|
|
||||||
|
@ -352,9 +382,10 @@ func sendInferReq(images []struct {
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
idx := rand.Intn(len(c.urls))
|
idx := rand.Intn(len(c.urls))
|
||||||
r, err := getInferResult(c.urls[idx].Url, t.file, t.imageResult.ImageName)
|
r, err := getInferResult(c.urls[idx].Url, t.file, t.imageResult.ImageName, c.clusterName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.imageResult.ImageResult = err.Error()
|
t.imageResult.ImageResult = err.Error()
|
||||||
|
t.imageResult.ClusterId = c.clusterId
|
||||||
t.imageResult.ClusterName = c.clusterName
|
t.imageResult.ClusterName = c.clusterName
|
||||||
t.imageResult.Card = c.urls[idx].Card
|
t.imageResult.Card = c.urls[idx].Card
|
||||||
ch <- t.imageResult
|
ch <- t.imageResult
|
||||||
|
@ -362,6 +393,7 @@ func sendInferReq(images []struct {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
t.imageResult.ImageResult = r
|
t.imageResult.ImageResult = r
|
||||||
|
t.imageResult.ClusterId = c.clusterId
|
||||||
t.imageResult.ClusterName = c.clusterName
|
t.imageResult.ClusterName = c.clusterName
|
||||||
t.imageResult.Card = c.urls[idx].Card
|
t.imageResult.Card = c.urls[idx].Card
|
||||||
|
|
||||||
|
@ -373,20 +405,113 @@ func sendInferReq(images []struct {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getInferResult(url string, file multipart.File, fileName string) (string, error) {
|
func getInferResult(url string, file multipart.File, fileName string, clusterName string) (string, error) {
|
||||||
|
if clusterName == "鹏城云脑II-modelarts" {
|
||||||
|
r, err := getInferResultModelarts(url, file, fileName)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return r, nil
|
||||||
|
}
|
||||||
var res Res
|
var res Res
|
||||||
req := GetRestyRequest(10)
|
req := GetRestyRequest(10)
|
||||||
_, err := req.
|
_, err := req.
|
||||||
SetFileReader("file", fileName, file).
|
SetFileReader("file", fileName, file).
|
||||||
SetResult(&res).
|
SetResult(&res).
|
||||||
Post(url)
|
Post(url)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return res.Result, nil
|
return res.Result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getInferResultModelarts(url string, file multipart.File, fileName string) (string, error) {
|
||||||
|
var res Res
|
||||||
|
body, err := SendRequest("POST", url, file, fileName)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
errjson := json.Unmarshal([]byte(body), &res)
|
||||||
|
if errjson != nil {
|
||||||
|
log.Fatalf("Error parsing JSON: %s", errjson)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res.Result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignClient AK/SK签名认证
|
||||||
|
func SignClient(r *http.Request, writer *multipart.Writer) (*http.Client, error) {
|
||||||
|
r.Header.Add("content-type", "application/json;charset=UTF-8")
|
||||||
|
r.Header.Add("X-Project-Id", "d18190e28e3f45a281ef0b0696ec9d52")
|
||||||
|
r.Header.Add("x-stage", "RELEASE")
|
||||||
|
r.Header.Add("x-sdk-content-sha256", "UNSIGNED-PAYLOAD")
|
||||||
|
r.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
s := core.Signer{
|
||||||
|
Key: "UNEHPHO4Z7YSNPKRXFE4",
|
||||||
|
Secret: "JWXCE9qcYbc7RjpSRIWt4WgG3ZKF6Q4lPzkJReX9",
|
||||||
|
}
|
||||||
|
err := s.Sign(r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
//设置client信任所有证书
|
||||||
|
tr := &http.Transport{
|
||||||
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||||
|
}
|
||||||
|
client := &http.Client{
|
||||||
|
Transport: tr,
|
||||||
|
}
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func SendRequest(method, url string, file multipart.File, fileName string) (string, error) {
|
||||||
|
/*body := &bytes.Buffer{}
|
||||||
|
writer := multipart.NewWriter(body)*/
|
||||||
|
// 创建一个新的缓冲区以写入multipart表单
|
||||||
|
var body bytes.Buffer
|
||||||
|
// 创建一个新的multipart writer
|
||||||
|
writer := multipart.NewWriter(&body)
|
||||||
|
// 创建一个用于写入文件的表单字段
|
||||||
|
part, err := writer.CreateFormFile("file", fileName) // "file"是表单的字段名,第二个参数是文件名
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("Error creating form file:", err)
|
||||||
|
}
|
||||||
|
// 将文件的内容拷贝到multipart writer中
|
||||||
|
_, err = io.Copy(part, file)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("Error copying file data:", err)
|
||||||
|
|
||||||
|
}
|
||||||
|
err = writer.Close()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("Error closing multipart writer:", err)
|
||||||
|
}
|
||||||
|
request, err := http.NewRequest(method, url, &body)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("Error creating new request:", err)
|
||||||
|
//return nil, err
|
||||||
|
}
|
||||||
|
signedR, err := SignClient(request, writer)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("Error signing request:", err)
|
||||||
|
//return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := signedR.Do(request)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("Error sending request:", err)
|
||||||
|
//return nil, err
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
Resbody, err := io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("Error reading response body:", err)
|
||||||
|
//return nil, err
|
||||||
|
}
|
||||||
|
return string(Resbody), nil
|
||||||
|
}
|
||||||
|
|
||||||
func GetRestyRequest(timeoutSeconds int64) *resty.Request {
|
func GetRestyRequest(timeoutSeconds int64) *resty.Request {
|
||||||
client := resty.New().SetTimeout(time.Duration(timeoutSeconds) * time.Second)
|
client := resty.New().SetTimeout(time.Duration(timeoutSeconds) * time.Second)
|
||||||
request := client.R()
|
request := client.R()
|
||||||
|
|
|
@ -5910,3 +5910,14 @@ type ImageResult struct {
|
||||||
Card string `json:"card"`
|
Card string `json:"card"`
|
||||||
ImageResult string `json:"imageResult"`
|
ImageResult string `json:"imageResult"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type InferenceTaskDetailReq struct {
|
||||||
|
AiTaskId int64 `json:"aiTaskId"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type InferenceTaskDetailResp struct {
|
||||||
|
ImageName string `json:"imageName"`
|
||||||
|
Result string `json:"result"`
|
||||||
|
Card string `json:"card"`
|
||||||
|
ClusterName string `json:"clusterName"`
|
||||||
|
}
|
||||||
|
|
1
go.mod
1
go.mod
|
@ -7,6 +7,7 @@ toolchain go1.22.4
|
||||||
retract v0.1.20-0.20240319015239-6ae13da05255
|
retract v0.1.20-0.20240319015239-6ae13da05255
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/JCCE-nudt/apigw-go-sdk v0.0.0-20230525025609-34159d6f2818
|
||||||
github.com/Masterminds/squirrel v1.5.4
|
github.com/Masterminds/squirrel v1.5.4
|
||||||
github.com/bwmarrin/snowflake v0.3.0
|
github.com/bwmarrin/snowflake v0.3.0
|
||||||
github.com/ghodss/yaml v1.0.0
|
github.com/ghodss/yaml v1.0.0
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -38,6 +38,8 @@ github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym
|
||||||
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||||
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||||
github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ=
|
github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ=
|
||||||
|
github.com/JCCE-nudt/apigw-go-sdk v0.0.0-20230525025609-34159d6f2818 h1:QLulhUyxPDs9FFieVZwmKAnUBLeRDhsVNehotAAL/FE=
|
||||||
|
github.com/JCCE-nudt/apigw-go-sdk v0.0.0-20230525025609-34159d6f2818/go.mod h1:j+am5/1URgsvyhOAyURFR9vH3malaW7Tq6d33OyPsnM=
|
||||||
github.com/Masterminds/squirrel v1.5.4 h1:uUcX/aBc8O7Fg9kaISIUsHXdKuqehiXAMQTYX8afzqM=
|
github.com/Masterminds/squirrel v1.5.4 h1:uUcX/aBc8O7Fg9kaISIUsHXdKuqehiXAMQTYX8afzqM=
|
||||||
github.com/Masterminds/squirrel v1.5.4/go.mod h1:NNaOrjSoIDfDA40n7sr2tPNZRfjzjA400rg+riTZj10=
|
github.com/Masterminds/squirrel v1.5.4/go.mod h1:NNaOrjSoIDfDA40n7sr2tPNZRfjzjA400rg+riTZj10=
|
||||||
github.com/alecthomas/kingpin/v2 v2.4.0 h1:f48lwail6p8zpO1bC4TxtqACaGqHYA22qkHjHpqDjYY=
|
github.com/alecthomas/kingpin/v2 v2.4.0 h1:f48lwail6p8zpO1bC4TxtqACaGqHYA22qkHjHpqDjYY=
|
||||||
|
|
|
@ -0,0 +1,29 @@
|
||||||
|
package models
|
||||||
|
|
||||||
|
import "github.com/zeromicro/go-zero/core/stores/sqlx"
|
||||||
|
|
||||||
|
var _ TaskAiSubModel = (*customTaskAiSubModel)(nil)
|
||||||
|
|
||||||
|
type (
|
||||||
|
// TaskAiSubModel is an interface to be customized, add more methods here,
|
||||||
|
// and implement the added methods in customTaskAiSubModel.
|
||||||
|
TaskAiSubModel interface {
|
||||||
|
taskAiSubModel
|
||||||
|
withSession(session sqlx.Session) TaskAiSubModel
|
||||||
|
}
|
||||||
|
|
||||||
|
customTaskAiSubModel struct {
|
||||||
|
*defaultTaskAiSubModel
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewTaskAiSubModel returns a model for the database table.
|
||||||
|
func NewTaskAiSubModel(conn sqlx.SqlConn) TaskAiSubModel {
|
||||||
|
return &customTaskAiSubModel{
|
||||||
|
defaultTaskAiSubModel: newTaskAiSubModel(conn),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *customTaskAiSubModel) withSession(session sqlx.Session) TaskAiSubModel {
|
||||||
|
return NewTaskAiSubModel(sqlx.NewSqlConnFromSession(session))
|
||||||
|
}
|
|
@ -0,0 +1,88 @@
|
||||||
|
// Code generated by goctl. DO NOT EDIT.
|
||||||
|
|
||||||
|
package models
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/zeromicro/go-zero/core/stores/builder"
|
||||||
|
"github.com/zeromicro/go-zero/core/stores/sqlc"
|
||||||
|
"github.com/zeromicro/go-zero/core/stores/sqlx"
|
||||||
|
"github.com/zeromicro/go-zero/core/stringx"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
taskAiSubFieldNames = builder.RawFieldNames(&TaskAiSub{})
|
||||||
|
taskAiSubRows = strings.Join(taskAiSubFieldNames, ",")
|
||||||
|
taskAiSubRowsExpectAutoSet = strings.Join(stringx.Remove(taskAiSubFieldNames, "`id`", "`create_at`", "`create_time`", "`created_at`", "`update_at`", "`update_time`", "`updated_at`"), ",")
|
||||||
|
taskAiSubRowsWithPlaceHolder = strings.Join(stringx.Remove(taskAiSubFieldNames, "`id`", "`create_at`", "`create_time`", "`created_at`", "`update_at`", "`update_time`", "`updated_at`"), "=?,") + "=?"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
taskAiSubModel interface {
|
||||||
|
Insert(ctx context.Context, data *TaskAiSub) (sql.Result, error)
|
||||||
|
FindOne(ctx context.Context, id int64) (*TaskAiSub, error)
|
||||||
|
Update(ctx context.Context, data *TaskAiSub) error
|
||||||
|
Delete(ctx context.Context, id int64) error
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultTaskAiSubModel struct {
|
||||||
|
conn sqlx.SqlConn
|
||||||
|
table string
|
||||||
|
}
|
||||||
|
|
||||||
|
TaskAiSub struct {
|
||||||
|
Id int64 `db:"id"` // id
|
||||||
|
ImageName string `db:"image_name"` // 图片名称
|
||||||
|
Result string `db:"result"` // 识别结果
|
||||||
|
Card string `db:"card"` // 加速卡
|
||||||
|
ClusterId int64 `db:"cluster_id"` // 集群id
|
||||||
|
ClusterName string `db:"cluster_name"` // 集群名称
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTaskAiSubModel(conn sqlx.SqlConn) *defaultTaskAiSubModel {
|
||||||
|
return &defaultTaskAiSubModel{
|
||||||
|
conn: conn,
|
||||||
|
table: "`task_ai_sub`",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *defaultTaskAiSubModel) Delete(ctx context.Context, id int64) error {
|
||||||
|
query := fmt.Sprintf("delete from %s where `id` = ?", m.table)
|
||||||
|
_, err := m.conn.ExecCtx(ctx, query, id)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *defaultTaskAiSubModel) FindOne(ctx context.Context, id int64) (*TaskAiSub, error) {
|
||||||
|
query := fmt.Sprintf("select %s from %s where `id` = ? limit 1", taskAiSubRows, m.table)
|
||||||
|
var resp TaskAiSub
|
||||||
|
err := m.conn.QueryRowCtx(ctx, &resp, query, id)
|
||||||
|
switch err {
|
||||||
|
case nil:
|
||||||
|
return &resp, nil
|
||||||
|
case sqlc.ErrNotFound:
|
||||||
|
return nil, ErrNotFound
|
||||||
|
default:
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *defaultTaskAiSubModel) Insert(ctx context.Context, data *TaskAiSub) (sql.Result, error) {
|
||||||
|
query := fmt.Sprintf("insert into %s (%s) values (?, ?, ?, ?, ?)", m.table, taskAiSubRowsExpectAutoSet)
|
||||||
|
ret, err := m.conn.ExecCtx(ctx, query, data.ImageName, data.Result, data.Card, data.ClusterId, data.ClusterName)
|
||||||
|
return ret, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *defaultTaskAiSubModel) Update(ctx context.Context, data *TaskAiSub) error {
|
||||||
|
query := fmt.Sprintf("update %s set %s where `id` = ?", m.table, taskAiSubRowsWithPlaceHolder)
|
||||||
|
_, err := m.conn.ExecCtx(ctx, query, data.ImageName, data.Result, data.Card, data.ClusterId, data.ClusterName, data.Id)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *defaultTaskAiSubModel) tableName() string {
|
||||||
|
return m.table
|
||||||
|
}
|
|
@ -0,0 +1,42 @@
|
||||||
|
// based on https://github.com/golang/go/blob/master/src/net/url/url.go
|
||||||
|
// Copyright 2009 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package core
|
||||||
|
|
||||||
|
func shouldEscape(c byte) bool {
|
||||||
|
if 'A' <= c && c <= 'Z' || 'a' <= c && c <= 'z' || '0' <= c && c <= '9' || c == '_' || c == '-' || c == '~' || c == '.' {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
func escape(s string) string {
|
||||||
|
hexCount := 0
|
||||||
|
for i := 0; i < len(s); i++ {
|
||||||
|
c := s[i]
|
||||||
|
if shouldEscape(c) {
|
||||||
|
hexCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if hexCount == 0 {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
t := make([]byte, len(s)+2*hexCount)
|
||||||
|
j := 0
|
||||||
|
for i := 0; i < len(s); i++ {
|
||||||
|
switch c := s[i]; {
|
||||||
|
case shouldEscape(c):
|
||||||
|
t[j] = '%'
|
||||||
|
t[j+1] = "0123456789ABCDEF"[c>>4]
|
||||||
|
t[j+2] = "0123456789ABCDEF"[c&15]
|
||||||
|
j += 3
|
||||||
|
default:
|
||||||
|
t[j] = s[i]
|
||||||
|
j++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return string(t)
|
||||||
|
}
|
|
@ -0,0 +1,201 @@
|
||||||
|
// HWS API Gateway Signature
|
||||||
|
// based on https://github.com/datastream/aws/blob/master/signv4.go
|
||||||
|
// Copyright (c) 2014, Xianjie
|
||||||
|
|
||||||
|
package core
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha256"
|
||||||
|
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
DateFormat = "20060102T150405Z"
|
||||||
|
SignAlgorithm = "SDK-HMAC-SHA256"
|
||||||
|
HeaderXDateTime = "X-Sdk-Date"
|
||||||
|
HeaderXHost = "host"
|
||||||
|
HeaderXAuthorization = "Authorization"
|
||||||
|
HeaderXContentSha256 = "X-Sdk-Content-Sha256"
|
||||||
|
)
|
||||||
|
|
||||||
|
func hmacsha256(keyByte []byte, dataStr string) ([]byte, error) {
|
||||||
|
hm := hmac.New(sha256.New, []byte(keyByte))
|
||||||
|
if _, err := hm.Write([]byte(dataStr)); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return hm.Sum(nil), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build a CanonicalRequest from a regular request string
|
||||||
|
func CanonicalRequest(request *http.Request, signedHeaders []string) (string, error) {
|
||||||
|
var hexencode string
|
||||||
|
var err error
|
||||||
|
if hex := request.Header.Get(HeaderXContentSha256); hex != "" {
|
||||||
|
hexencode = hex
|
||||||
|
} else {
|
||||||
|
bodyData, err := RequestPayload(request)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
hexencode, err = HexEncodeSHA256Hash(bodyData)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s", request.Method, CanonicalURI(request), CanonicalQueryString(request), CanonicalHeaders(request, signedHeaders), strings.Join(signedHeaders, ";"), hexencode), err
|
||||||
|
}
|
||||||
|
|
||||||
|
// CanonicalURI returns request uri
|
||||||
|
func CanonicalURI(request *http.Request) string {
|
||||||
|
pattens := strings.Split(request.URL.Path, "/")
|
||||||
|
var uriSlice []string
|
||||||
|
for _, v := range pattens {
|
||||||
|
uriSlice = append(uriSlice, escape(v))
|
||||||
|
}
|
||||||
|
urlpath := strings.Join(uriSlice, "/")
|
||||||
|
if len(urlpath) == 0 || urlpath[len(urlpath)-1] != '/' {
|
||||||
|
urlpath = urlpath + "/"
|
||||||
|
}
|
||||||
|
return urlpath
|
||||||
|
}
|
||||||
|
|
||||||
|
// CanonicalQueryString
|
||||||
|
func CanonicalQueryString(request *http.Request) string {
|
||||||
|
var keys []string
|
||||||
|
queryMap := request.URL.Query()
|
||||||
|
for key := range queryMap {
|
||||||
|
keys = append(keys, key)
|
||||||
|
}
|
||||||
|
sort.Strings(keys)
|
||||||
|
var query []string
|
||||||
|
for _, key := range keys {
|
||||||
|
k := escape(key)
|
||||||
|
sort.Strings(queryMap[key])
|
||||||
|
for _, v := range queryMap[key] {
|
||||||
|
kv := fmt.Sprintf("%s=%s", k, escape(v))
|
||||||
|
query = append(query, kv)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
queryStr := strings.Join(query, "&")
|
||||||
|
request.URL.RawQuery = queryStr
|
||||||
|
return queryStr
|
||||||
|
}
|
||||||
|
|
||||||
|
// CanonicalHeaders
|
||||||
|
func CanonicalHeaders(request *http.Request, signerHeaders []string) string {
|
||||||
|
var canonicalHeaders []string
|
||||||
|
header := make(map[string][]string)
|
||||||
|
for k, v := range request.Header {
|
||||||
|
header[strings.ToLower(k)] = v
|
||||||
|
}
|
||||||
|
for _, key := range signerHeaders {
|
||||||
|
value := header[key]
|
||||||
|
if strings.EqualFold(key, HeaderXHost) {
|
||||||
|
value = []string{request.Host}
|
||||||
|
}
|
||||||
|
sort.Strings(value)
|
||||||
|
for _, v := range value {
|
||||||
|
canonicalHeaders = append(canonicalHeaders, key+":"+strings.TrimSpace(v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s\n", strings.Join(canonicalHeaders, "\n"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignedHeaders
|
||||||
|
func SignedHeaders(r *http.Request) []string {
|
||||||
|
var signedHeaders []string
|
||||||
|
for key := range r.Header {
|
||||||
|
signedHeaders = append(signedHeaders, strings.ToLower(key))
|
||||||
|
}
|
||||||
|
sort.Strings(signedHeaders)
|
||||||
|
return signedHeaders
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestPayload
|
||||||
|
func RequestPayload(request *http.Request) ([]byte, error) {
|
||||||
|
if request.Body == nil {
|
||||||
|
return []byte(""), nil
|
||||||
|
}
|
||||||
|
bodyByte, err := ioutil.ReadAll(request.Body)
|
||||||
|
if err != nil {
|
||||||
|
return []byte(""), err
|
||||||
|
}
|
||||||
|
request.Body = ioutil.NopCloser(bytes.NewBuffer(bodyByte))
|
||||||
|
return bodyByte, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a "String to Sign".
|
||||||
|
func StringToSign(canonicalRequest string, t time.Time) (string, error) {
|
||||||
|
hashStruct := sha256.New()
|
||||||
|
_, err := hashStruct.Write([]byte(canonicalRequest))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s\n%s\n%x",
|
||||||
|
SignAlgorithm, t.UTC().Format(DateFormat), hashStruct.Sum(nil)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the HWS Signature.
|
||||||
|
func SignStringToSign(stringToSign string, signingKey []byte) (string, error) {
|
||||||
|
hmsha, err := hmacsha256(signingKey, stringToSign)
|
||||||
|
return fmt.Sprintf("%x", hmsha), err
|
||||||
|
}
|
||||||
|
|
||||||
|
// HexEncodeSHA256Hash returns hexcode of sha256
|
||||||
|
func HexEncodeSHA256Hash(body []byte) (string, error) {
|
||||||
|
hashStruct := sha256.New()
|
||||||
|
if len(body) == 0 {
|
||||||
|
body = []byte("")
|
||||||
|
}
|
||||||
|
_, err := hashStruct.Write(body)
|
||||||
|
return fmt.Sprintf("%x", hashStruct.Sum(nil)), err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the finalized value for the "Authorization" header. The signature parameter is the output from SignStringToSign
|
||||||
|
func AuthHeaderValue(signatureStr, accessKeyStr string, signedHeaders []string) string {
|
||||||
|
return fmt.Sprintf("%s Access=%s, SignedHeaders=%s, Signature=%s", SignAlgorithm, accessKeyStr, strings.Join(signedHeaders, ";"), signatureStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Signature HWS meta
|
||||||
|
type Signer struct {
|
||||||
|
Key string
|
||||||
|
Secret string
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignRequest set Authorization header
|
||||||
|
func (s *Signer) Sign(request *http.Request) error {
|
||||||
|
var t time.Time
|
||||||
|
var err error
|
||||||
|
var date string
|
||||||
|
if date = request.Header.Get(HeaderXDateTime); date != "" {
|
||||||
|
t, err = time.Parse(DateFormat, date)
|
||||||
|
}
|
||||||
|
if err != nil || date == "" {
|
||||||
|
t = time.Now()
|
||||||
|
request.Header.Set(HeaderXDateTime, t.UTC().Format(DateFormat))
|
||||||
|
}
|
||||||
|
signedHeaders := SignedHeaders(request)
|
||||||
|
canonicalRequest, err := CanonicalRequest(request, signedHeaders)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
stringToSignStr, err := StringToSign(canonicalRequest, t)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
signatureStr, err := SignStringToSign(stringToSignStr, []byte(s.Secret))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
authValueStr := AuthHeaderValue(signatureStr, s.Key, signedHeaders)
|
||||||
|
request.Header.Set(HeaderXAuthorization, authValueStr)
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -2,6 +2,9 @@
|
||||||
## explicit; go 1.20
|
## explicit; go 1.20
|
||||||
filippo.io/edwards25519
|
filippo.io/edwards25519
|
||||||
filippo.io/edwards25519/field
|
filippo.io/edwards25519/field
|
||||||
|
# github.com/JCCE-nudt/apigw-go-sdk v0.0.0-20230525025609-34159d6f2818
|
||||||
|
## explicit; go 1.17
|
||||||
|
github.com/JCCE-nudt/apigw-go-sdk/core
|
||||||
# github.com/Masterminds/squirrel v1.5.4
|
# github.com/Masterminds/squirrel v1.5.4
|
||||||
## explicit; go 1.14
|
## explicit; go 1.14
|
||||||
github.com/Masterminds/squirrel
|
github.com/Masterminds/squirrel
|
||||||
|
|
Loading…
Reference in New Issue